# finetuning_vit_for_image_classification.py
# %%
!pip install transformers evaluate datasets
# %%
import requests
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification, EvalPrediction
from tqdm import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
# %%
# the model name
model_name = "google/vit-base-patch16-224"
# load the image processor
image_processor = ViTImageProcessor.from_pretrained(model_name)
# loading the pre-trained model
model = ViTForImageClassification.from_pretrained(model_name).to(device)
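# %%
# quick sanity check (optional): this checkpoint was fine-tuned on ImageNet-1k,
# so the config should expose 1,000 class labels
print(model.config.num_labels)
print(list(model.config.id2label.items())[:3])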
# %%
import urllib.parse as parse
import os
# a function to determine whether a string is a URL or not
def is_url(string):
    try:
        result = parse.urlparse(string)
        return all([result.scheme, result.netloc, result.path])
    except ValueError:
        return False
# a function to load an image
def load_image(image_path):
    if is_url(image_path):
        return Image.open(requests.get(image_path, stream=True).raw)
    elif os.path.exists(image_path):
        return Image.open(image_path)
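# %%
# quick test of load_image on the same COCO image used below; in a notebook,
# the returned PIL image is displayed inline
load_image("http://images.cocodataset.org/test-stuff2017/000000000128.jpg")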
# %%
def get_prediction(model, url_or_path):
    # load the image
    img = load_image(url_or_path)
    # preprocess the image
    pixel_values = image_processor(img, return_tensors="pt")["pixel_values"].to(device)
    # perform inference (no gradients needed)
    with torch.no_grad():
        output = model(pixel_values)
    # get the label id and return the class name
    return model.config.id2label[int(output.logits.softmax(dim=1).argmax())]
# %%
get_prediction(model, "http://images.cocodataset.org/test-stuff2017/000000000128.jpg")
# %% [markdown]
# # Loading our Dataset
# %%
from datasets import load_dataset
# download & load the dataset
ds = load_dataset("food101")
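# %%
# food101 is fairly large (75,750 train / 25,250 validation images); for a quick
# experiment you can load just a slice of each split with datasets' split-slicing
# syntax (commented out so the full-dataset run above stays the default):
# from datasets import DatasetDict
# ds = DatasetDict({
#     "train": load_dataset("food101", split="train[:5000]"),
#     "validation": load_dataset("food101", split="validation[:1000]"),
# })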
# %% [markdown]
# ## Loading a Custom Dataset using `ImageFolder`
# Run the three cells below to load a custom dataset (one that's not on the Hub) using `ImageFolder`
# %%
import requests
from tqdm import tqdm
def get_file(url):
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    filename = None
    content_disposition = response.headers.get('content-disposition')
    if content_disposition:
        parts = content_disposition.split(';')
        for part in parts:
            if 'filename' in part:
                filename = part.split('=')[1].strip('"')
    if not filename:
        filename = os.path.basename(url)
    block_size = 1024  # 1 Kibibyte
    tqdm_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            tqdm_bar.update(len(data))
            file.write(data)
    tqdm_bar.close()
    print(f"Downloaded {filename} ({total_size} bytes)")
    return filename
# %%
import zipfile
import os
def download_and_extract_dataset():
    # dataset from https://github.com/udacity/dermatologist-ai
    # 5.3GB
    train_url = "https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/skin-cancer/train.zip"
    # 824.5MB
    valid_url = "https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/skin-cancer/valid.zip"
    # 5.1GB
    test_url = "https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/skin-cancer/test.zip"
    for download_link in [valid_url, train_url, test_url]:
        data_dir = get_file(download_link)
        print("Extracting", download_link)
        with zipfile.ZipFile(data_dir, "r") as z:
            z.extractall("data")
        # remove the temp file
        os.remove(data_dir)
# comment out the line below if you already downloaded the dataset
download_and_extract_dataset()
# %%
from datasets import load_dataset
# load the custom dataset
ds = load_dataset("imagefolder", data_dir="data")
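# %% [markdown]
# Note: `imagefolder` infers the splits from the top-level `train`/`valid`/`test`
# directories and the class labels from the subdirectory names, so after extraction
# the data is expected to look roughly like this (class names are those of the
# dermatologist-ai dataset):
# ```
# data/
# ├── train/
# │   ├── melanoma/
# │   ├── nevus/
# │   └── seborrheic_keratosis/
# ├── valid/
# └── test/
# ```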
# %% [markdown]
# # Exploring the Data
# %%
ds
# %%
labels = ds["train"].features["label"]
labels
# %%
labels.int2str(ds["train"][532]["label"])
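# %%
# `labels` is a ClassLabel feature; the full list of class names lives in `labels.names`
labels.names[:5]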
# %%
import random
import matplotlib.pyplot as plt
def show_image_grid(dataset, split, grid_size=(4, 4)):
    # select random images from the given split
    indices = random.sample(range(len(dataset[split])), grid_size[0] * grid_size[1])
    images = [dataset[split][i]["image"] for i in indices]
    labels = [dataset[split][i]["label"] for i in indices]
    # display the images in a grid
    fig, axes = plt.subplots(nrows=grid_size[0], ncols=grid_size[1], figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i])
        ax.axis('off')
        ax.set_title(dataset[split].features["label"].int2str(labels[i]))
    plt.show()
# %%
show_image_grid(ds, "train")
# %% [markdown]
# # Preprocessing the Data
# %%
def transform(examples):
    # convert all images to RGB format, then preprocess them
    # using our image processor
    inputs = image_processor([img.convert("RGB") for img in examples["image"]], return_tensors="pt")
    # don't forget to include the labels
    inputs["labels"] = examples["label"]
    return inputs
# %%
# use the with_transform() method to apply the transform to the dataset on the fly during training
dataset = ds.with_transform(transform)
# %%
for item in dataset["train"]:
    print(item["pixel_values"].shape)
    print(item["labels"])
    break
# %%
# extract the labels for our dataset
labels = ds["train"].features["label"].names
labels
# %%
import torch
def collate_fn(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }
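# %%
# optional sanity check: collate a couple of transformed examples by hand and
# verify the batch shapes (this mirrors what the Trainer does internally)
sample_batch = collate_fn([dataset["train"][0], dataset["train"][1]])
print(sample_batch["pixel_values"].shape)  # expected: torch.Size([2, 3, 224, 224])
print(sample_batch["labels"])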
# %% [markdown]
# # Defining the Metrics
# %%
from evaluate import load
import numpy as np
# load the accuracy and f1 metrics from the evaluate module
accuracy = load("accuracy")
f1 = load("f1")
def compute_metrics(eval_pred):
    # convert the logits to class predictions
    predictions = np.argmax(eval_pred.predictions, axis=1)
    # compute the accuracy and f1 scores & return them
    accuracy_score = accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
    f1_score = f1.compute(predictions=predictions, references=eval_pred.label_ids, average="macro")
    return {**accuracy_score, **f1_score}
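# %%
# optional sanity check for compute_metrics with tiny dummy logits; both fake
# predictions below match their labels, so both scores should come out as 1.0
dummy = EvalPrediction(
    predictions=np.array([[0.1, 0.9], [0.8, 0.2]]),  # logits for 2 examples
    label_ids=np.array([1, 0]),
)
compute_metrics(dummy)  # expected: {'accuracy': 1.0, 'f1': 1.0}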
# %% [markdown]
# # Training the Model
# %%
# load the ViT model with a new classification head sized for our labels
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
    ignore_mismatched_sizes=True,  # the new head doesn't match the 1000-class ImageNet head
)
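# %%
# the classification head was re-initialized because of ignore_mismatched_sizes=True;
# it should now be a Linear layer with len(labels) output features
model.classifier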
# %%
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="./vit-base-food",    # output directory
    # output_dir="./vit-base-skin-cancer",
    per_device_train_batch_size=32,  # batch size per device during training
    evaluation_strategy="steps",     # evaluate during training every `eval_steps`
    num_train_epochs=3,              # total number of training epochs
    # fp16=True,                     # use mixed precision
    save_steps=1000,                 # number of update steps before saving a checkpoint
    eval_steps=1000,                 # number of update steps before evaluating
    logging_steps=1000,              # number of update steps before logging
    # save_steps=50,
    # eval_steps=50,
    # logging_steps=50,
    save_total_limit=2,              # limit the total number of checkpoints on disk
    remove_unused_columns=False,     # keep the image column; our transform needs it
    push_to_hub=False,               # do not push the model to the hub
    report_to='tensorboard',         # report metrics to tensorboard
    load_best_model_at_end=True,     # load the best model at the end of training
)
# %%
from transformers import Trainer
trainer = Trainer(
    model=model,                          # the instantiated 🤗 Transformers model to be trained
    args=training_args,                   # training arguments, defined above
    data_collator=collate_fn,             # the data collator used for batching
    compute_metrics=compute_metrics,      # the metrics function used for evaluation
    train_dataset=dataset["train"],       # training dataset
    eval_dataset=dataset["validation"],   # evaluation dataset
    tokenizer=image_processor,            # the processor used for preprocessing the images
)
# %%
# start training
trainer.train()
# %%
# trainer.evaluate(dataset["test"])
trainer.evaluate()
# %%
# start tensorboard
# %load_ext tensorboard
%reload_ext tensorboard
%tensorboard --logdir ./vit-base-food/runs
# %% [markdown]
# ## Alternatively: Training using PyTorch Loop
# Run the two cells below to fine-tune with a regular PyTorch loop instead, if you prefer.
# %%
# Training loop
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW
from torch.utils.data import DataLoader

batch_size = 32
train_dataset_loader = DataLoader(dataset["train"], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
# no need to shuffle the validation set
valid_dataset_loader = DataLoader(dataset["validation"], collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
# define the optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)
log_dir = "./image-classification/tensorboard"
summary_writer = SummaryWriter(log_dir=log_dir)
num_epochs = 3
model = model.to(device)
# number of training steps
n_train_steps = num_epochs * len(train_dataset_loader)
# number of validation steps
n_valid_steps = len(valid_dataset_loader)
# current training step
current_step = 0
# logging, eval & save steps
save_steps = 1000

def compute_metrics(eval_pred):
    # unlike the Trainer version above, the predictions here are already class ids
    accuracy_score = accuracy.compute(predictions=eval_pred.predictions, references=eval_pred.label_ids)
    f1_score = f1.compute(predictions=eval_pred.predictions, references=eval_pred.label_ids, average="macro")
    return {**accuracy_score, **f1_score}
# %%
for epoch in range(num_epochs):
    # set the model to training mode
    model.train()
    # initialize the training loss
    train_loss = 0
    # initialize the progress bar
    progress_bar = tqdm(range(current_step, n_train_steps), "Training", dynamic_ncols=True)
    for batch in train_dataset_loader:
        if (current_step + 1) % save_steps == 0:
            ### evaluation code ###
            # evaluate on the validation set
            # if the current step is a multiple of the save steps
            print()
            print(f"Validation at step {current_step}...")
            print()
            # set the model to evaluation mode
            model.eval()
            # initialize our lists that store the predictions and the labels
            predictions, labels = [], []
            # initialize the validation loss
            valid_loss = 0
            # use a distinct loop variable so we don't overwrite the training batch
            for valid_batch in valid_dataset_loader:
                # get the batch
                pixel_values = valid_batch["pixel_values"].to(device)
                label_ids = valid_batch["labels"].to(device)
                # forward pass (no gradients needed for validation)
                with torch.no_grad():
                    outputs = model(pixel_values=pixel_values, labels=label_ids)
                # get the loss
                loss = outputs.loss
                valid_loss += loss.item()
                # move the logits off the GPU
                logits = outputs.logits.detach().cpu()
                # add the predictions to the list
                predictions.extend(logits.argmax(dim=-1).tolist())
                # add the labels to the list
                labels.extend(label_ids.tolist())
            # make the EvalPrediction object that the compute_metrics function expects
            eval_prediction = EvalPrediction(predictions=predictions, label_ids=labels)
            # compute the metrics
            metrics = compute_metrics(eval_prediction)
            # print the stats
            print()
            print(f"Epoch: {epoch}, Step: {current_step}, Train Loss: {train_loss / save_steps:.4f}, " +
                  f"Valid Loss: {valid_loss / n_valid_steps:.4f}, Accuracy: {metrics['accuracy']}, " +
                  f"F1 Score: {metrics['f1']}")
            print()
            # log the metrics
            summary_writer.add_scalar("valid_loss", valid_loss / n_valid_steps, global_step=current_step)
            summary_writer.add_scalar("accuracy", metrics["accuracy"], global_step=current_step)
            summary_writer.add_scalar("f1", metrics["f1"], global_step=current_step)
            # save the model
            model.save_pretrained(f"./vit-base-food/checkpoint-{current_step}")
            image_processor.save_pretrained(f"./vit-base-food/checkpoint-{current_step}")
            # get the model back to train mode
            model.train()
            # reset the train and valid loss
            train_loss, valid_loss = 0, 0
        ### training code below ###
        # get the batch & move it to the device
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        # forward pass
        outputs = model(pixel_values=pixel_values, labels=labels)
        # get the loss
        loss = outputs.loss
        # backward pass
        loss.backward()
        # update the weights
        optimizer.step()
        # zero the gradients
        optimizer.zero_grad()
        # log the loss
        loss_v = loss.item()
        train_loss += loss_v
        # increment the step
        current_step += 1
        progress_bar.update(1)
        # log the training loss
        summary_writer.add_scalar("train_loss", loss_v, global_step=current_step)
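# %%
# optionally, save the final weights too, since the last step usually doesn't
# land exactly on a save_steps boundary ("checkpoint-final" is just an
# illustrative directory name)
model.save_pretrained("./vit-base-food/checkpoint-final")
image_processor.save_pretrained("./vit-base-food/checkpoint-final")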
# %% [markdown]
# # Performing Inference
# %%
# load the best model; change the checkpoint number to your best checkpoint
# (if the last checkpoint is the best, you can skip this cell)
best_checkpoint = 7000
# best_checkpoint = 150
model = ViTForImageClassification.from_pretrained(f"./vit-base-food/checkpoint-{best_checkpoint}").to(device)
# model = ViTForImageClassification.from_pretrained(f"./vit-base-skin-cancer/checkpoint-{best_checkpoint}").to(device)
# %%
get_prediction(model, "https://images.pexels.com/photos/858496/pexels-photo-858496.jpeg?auto=compress&cs=tinysrgb&w=600&lazy=load")
# %%
def get_prediction_probs(model, url_or_path, num_classes=3):
    # load the image
    img = load_image(url_or_path)
    # preprocess the image
    pixel_values = image_processor(img, return_tensors="pt")["pixel_values"].to(device)
    # perform inference (no gradients needed)
    with torch.no_grad():
        output = model(pixel_values)
    # get the top k classes and probabilities
    probs, indices = torch.topk(output.logits.softmax(dim=1), k=num_classes)
    # get the class labels
    id2label = model.config.id2label
    classes = [id2label[idx.item()] for idx in indices[0]]
    # convert the probabilities to a list
    probs = probs.squeeze().tolist()
    # create a dictionary with the class names and probabilities
    results = dict(zip(classes, probs))
    return results
# %%
# example 1
get_prediction_probs(model, "https://images.pexels.com/photos/406152/pexels-photo-406152.jpeg?auto=compress&cs=tinysrgb&w=600")
# %%
# example 2
get_prediction_probs(model, "https://images.pexels.com/photos/920220/pexels-photo-920220.jpeg?auto=compress&cs=tinysrgb&w=600")
# %%
# example 3
get_prediction_probs(model, "https://images.pexels.com/photos/3338681/pexels-photo-3338681.jpeg?auto=compress&cs=tinysrgb&w=600")
# %%
# example 4
get_prediction_probs(model, "https://images.pexels.com/photos/806457/pexels-photo-806457.jpeg?auto=compress&cs=tinysrgb&w=600", num_classes=10)
# %%
get_prediction_probs(model, "https://images.pexels.com/photos/1624487/pexels-photo-1624487.jpeg?auto=compress&cs=tinysrgb&w=600")