Code for Visual Question Answering with Transformers in Python Tutorial


running_blip_git.py

# -*- coding: utf-8 -*-
"""VisualQuestionAnswering_PythonCodeTutorial.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1dM89DgL_hg4K3uiKnTQ-p8rtS05wH_fX
"""

!pip install -qU transformers

"""# BLIP

- https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip/modeling_blip.py
- https://huggingface.co/Salesforce/blip-vqa-base/tree/main
"""

import requests
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering
import torch

# load the image we will test BLIP on
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
image

# load necessary components: the processor and the model
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
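# note (beyond the original code): the model stays on CPU here; it can be moved to a GPU
# with model.to("cuda"), in which case the processed inputs must be moved to "cuda" as well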

def get_answer_blip(model, processor, image, question):
    """Answers the given question and handles all the preprocessing and postprocessing steps"""
    # preprocess the given image and question
    inputs = processor(image, question, return_tensors="pt")
    # generate the answer (get output)
    out = model.generate(**inputs)
    # post-process the output to get human-friendly English text
    answer = processor.decode(out[0], skip_special_tokens=True)
    print(answer)
    return answer

# sample question 1
question = "how many dogs are in the picture?"
get_answer_blip(model, processor, image, question)

# sample question 2
question = "how will you describe the picture?"
get_answer_blip(model, processor, image, question)

# sample question 3
question = "where are they?"
get_answer_blip(model, processor, image, question)

# sample question 4
question = "What are they doing?"
get_answer_blip(model, processor, image, question)

# sample question 5
question = "What the dog is wearing?"
get_answer_blip(model, processor, image, question)

class BLIP_VQA:
    """Custom implementation of the BLIP model. The code has been adapted from the official transformers implementation"""

    def __init__(self, vision_model, text_encoder, text_decoder, processor):
        """Initialize various objects"""
        self.vision_model = vision_model
        self.text_encoder = text_encoder
        self.text_decoder = text_decoder
        self.processor = processor

    def preprocess(self, img, ques):
        """preprocess the inputs: image, question"""
        # preprocess using the processor
        inputs = self.processor(img, ques, return_tensors='pt')
        # store the pixel values of the image, input IDs (i.e., token IDs) of the question and the attention masks separately
        pixel_values = inputs['pixel_values']
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        return pixel_values, input_ids, attention_mask


    def generate_output(self, pixel_values, input_ids, attention_mask):
        """Generates output from the preprocessed input"""

        # get the vision outputs (i.e., the image embeds)
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        img_embeds = vision_outputs[0]

        # create attention mask with 1s on all the image embedding positions
        img_attention_mask = torch.ones(img_embeds.size()[:-1], dtype=torch.long)

        # encode the questions
        question_outputs = self.text_encoder(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             encoder_hidden_states=img_embeds,
                                             encoder_attention_mask=img_attention_mask,
                                             return_dict=False)

        # create attention mask with 1s on all the question token IDs positions
        question_embeds = question_outputs[0]
        question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long)

        # initialize the answers with the beginning-of-sentence IDs (bos ID)
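        # (note) 30522 is the ID of BLIP's decoder BOS token, added on top of the BERT vocab;
        # in the generate() call below, 102 is the BERT [SEP] token used as EOS and 0 is [PAD]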
        bos_ids = torch.full((question_embeds.size(0), 1), fill_value=30522)

        # get output from the decoder. These outputs are the generated IDs
        outputs = self.text_decoder.generate(
            input_ids=bos_ids,
            eos_token_id=102,
            pad_token_id=0,
            encoder_hidden_states=question_embeds,
            encoder_attention_mask=question_attention_mask)

        return outputs


    def postprocess(self, outputs):
        """post-process the output generated by the text-decoder"""

        return self.processor.decode(outputs[0], skip_special_tokens=True)


    def get_answer(self, image, ques):
        """Returns human friendly answer to a question"""

        # preprocess
        pixel_values, input_ids, attention_mask = self.preprocess(image, ques)
        # generate output
        outputs = self.generate_output(pixel_values, input_ids, attention_mask)
        # post-process
        answer = self.postprocess(outputs)
        return answer

blip_vqa = BLIP_VQA(vision_model=model.vision_model,
                    text_encoder=model.text_encoder,
                    text_decoder=model.text_decoder,
                    processor=processor)
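
# (optional) sanity check -- a minimal sketch, not part of the original notebook: since
# BLIP_VQA reuses the model's own sub-modules, its answer should match the one printed
# by the high-level helper defined earlier for the same question
sanity_question = "how many dogs are in the picture?"
print(blip_vqa.get_answer(image, sanity_question))
get_answer_blip(model, processor, image, sanity_question)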

# sample question 1
ques = "how will you describe the picture?"
print(blip_vqa.get_answer(image, ques))

# load another image to test BLIP
img_url = "https://fastly.picsum.photos/id/11/200/200.jpg?hmac=LBGO0uEpEmAVS8NeUXMqxcIdHGIcu0JiOb5DJr4mtUI"
image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
image

# sample question 1
ques = "Describe the picture"
print(blip_vqa.get_answer(image, ques))

# sample question 2
ques = "What is the major color present?"
print(blip_vqa.get_answer(image, ques))

# sample question 3
ques = "How's the weather?"
print(blip_vqa.get_answer(image, ques))

"""# GIT

- https://github.com/huggingface/transformers/blob/main/src/transformers/models/git/modeling_git.py
- https://huggingface.co/microsoft/git-base-textvqa
"""

!pip install -qU transformers

from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import torch

# load the image we will test GIT on
file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
image = Image.open(file_path).convert("RGB")
image

# load necessary components: the processor and the model
processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")

class GIT_VQA:
    """Custom implementation of the GIT model for Visual Question Answering (VQA) tasks."""

    def __init__(self, model, processor):
        """Initializes the model and the processor."""
        self.model = model
        self.processor = processor
        return


    def preprocess(self, image, question):
        """Preprocesses the inputs: image, question"""
        # process the image to get pixel values
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # process the question to get input IDs, but do not add special tokens
        input_ids = self.processor(text=question, add_special_tokens=False).input_ids

        # add the CLS token at the beginning of the input_ids and format for model input
        input_ids = [self.processor.tokenizer.cls_token_id] + input_ids
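        # (note) no SEP token is appended: the answer is generated as a continuation of
        # the question, so the sequence is left open-ended for the decoder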
        input_ids = torch.tensor(input_ids).unsqueeze(0)

        return pixel_values, input_ids


    def generate(self, pixel_values, input_ids):
        """Generates the output from the preprocessed inputs."""

        # generate output using the model with a maximum length of 50 tokens
        outputs = self.model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
        return outputs


    def postprocess(self, outputs):
        """Post-processes the output generated by the model."""

        # decode the output, ignoring special tokens
        answer = self.processor.batch_decode(outputs, skip_special_tokens=True)
        return answer


    def get_answer(self, image, question):
        """Returns human friendly answer to a question"""

        # preprocess
        pixel_values, input_ids = self.preprocess(image, question)
        # generate output
        outputs = self.generate(pixel_values, input_ids)
        # post-process
        answer = self.postprocess(outputs)
        return answer

# create a GIT instance
git_vqa = GIT_VQA(model=model, processor=processor)
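
# a small convenience sketch (not part of the original notebook) that runs several
# questions against the same image by reusing git_vqa.get_answer()
def ask_many(vqa, image, questions):
    """Return a {question: answer} mapping for a list of questions."""
    return {q: vqa.get_answer(image, q) for q in questions}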

# sample question 1
question = "what does the front of the bus say at the top?"
answer = git_vqa.get_answer(image, question)
print(answer)

# sample question 2
question = "what are all the colors present on the bus?"
answer = git_vqa.get_answer(image, question)
print(answer)

# sample question 3
question = "How many wheels you see in the bus?"
answer = git_vqa.get_answer(image, question)
print(answer)

# load another image to test GIT
img_url = "https://fastly.picsum.photos/id/110/500/500.jpg?hmac=wSHhLFNyJ6k3uM94s6etGQ0WWhmwbdUSiZ9ZDL5Hh2Q"
image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
image

# sample question 1
question = "Is it night in the image?"
answer = git_vqa.get_answer(image, question)
print(answer)

running_blip2.py

# %%
!pip install transformers accelerate

# %%
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import os

device = torch.device("cuda", 0)
device

# %%
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
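# (optional, an assumption beyond the original code) since accelerate is installed, the
# checkpoint could also be loaded with device_map="auto" to place its weights automatically:
# model = Blip2ForConditionalGeneration.from_pretrained(
#     "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")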

# %%
model.to(device)

# %%
import urllib.parse as parse
import os

# a function to determine whether a string is a URL or not
def is_url(string):
    try:
        result = parse.urlparse(string)
        return all([result.scheme, result.netloc, result.path])
    except Exception:
        return False
    
# a function to load an image from a URL or a local file path
def load_image(image_path):
    if is_url(image_path):
        return Image.open(requests.get(image_path, stream=True).raw)
    elif os.path.exists(image_path):
        return Image.open(image_path)
    else:
        raise ValueError(f"Could not load image from: {image_path}")
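
# %%
# quick check of the helper above: a URL is recognized, an arbitrary string is not
print(is_url("http://images.cocodataset.org/test-stuff2017/000000007226.jpg"))  # True
print(is_url("not a url"))                                                      # False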

# %%
raw_image = load_image("http://images.cocodataset.org/test-stuff2017/000000007226.jpg")

# %%
question = "a"
inputs = processor(raw_image, question, return_tensors="pt").to(device, dtype=torch.float16)

# %%
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

# %%
question = "a vintage car driving down a street"
inputs = processor(raw_image, question, return_tensors="pt").to(device, dtype=torch.float16)

# %%
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

# %%
question = "Question: What is the estimated year of these cars? Answer:"
inputs = processor(raw_image, question, return_tensors="pt").to(device, dtype=torch.float16)

# %%
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

# %%
question = "Question: What is the color of the car? Answer:"
inputs = processor(raw_image, question, return_tensors="pt").to(device, dtype=torch.float16)

# %%
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

# %%
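# a reusable helper -- a minimal sketch built only on the objects already loaded above
# (processor, model, device); it wraps the preprocess / generate / decode steps so new
# prompts can be tried with a single call
def ask_blip2(image, prompt, max_new_tokens=30):
    inputs = processor(image, prompt, return_tensors="pt").to(device, dtype=torch.float16)
    out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.decode(out[0], skip_special_tokens=True).strip()

print(ask_blip2(raw_image, "Question: How many cars are in the picture? Answer:"))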