Code for How to Upscale Images using Stable Diffusion in Python Tutorial


stable_diffusion_upscaler.py

# %%
!pip install -qU diffusers transformers accelerate scipy safetensors

# %% [markdown]
# # Hugging Face Implementation

# %%
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch

# %%
# load model and scheduler
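# (fp16 weights on a CUDA GPU are assumed here; the x4 upscaler enlarges each image dimension by 4)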
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
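
# %%
# Optional: if you run into CUDA out-of-memory errors, attention slicing trades a
# little speed for a sizeable VRAM saving. `enable_attention_slicing()` is a standard
# diffusers pipeline method; skip this cell if memory is not a concern.
pipeline.enable_attention_slicing()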

# %%
def get_low_res_img(url, shape):
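    """downloads the image at `url` and resizes it to `shape`, a (width, height) tuple as PIL expects"""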
    response = requests.get(url)
    low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
    low_res_img = low_res_img.resize(shape)
    return low_res_img

# %%
url = "https://cdn.pixabay.com/photo/2017/02/07/16/47/kingfisher-2046453_640.jpg"
shape = (200, 128)
low_res_img = get_low_res_img(url, shape)

low_res_img

# %%
prompt = "an aesthetic kingfisher"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image

# %%
prompt = "an aesthetic kingfisher, UHD, 4k, hyper realistic, extremely detailed, professional, vibrant, not grainy, smooth"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image
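
# %%
# optional: persist the result using PIL's standard save method
# (the filename is just an example)
upscaled_image.save("upscaled_kingfisher.png")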

# %%
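# the x4 upscaler quadruples each dimension, so resize to (200 * 4, 128 * 4) = (800, 512) for a fair comparison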
upscaled_interpolation = low_res_img.resize((800, 512))
upscaled_interpolation

# %%
url = "https://cdn.pixabay.com/photo/2022/06/14/20/57/woman-7262808_1280.jpg"
shape = (200, 128)
low_res_img = get_low_res_img(url, shape)

low_res_img

# %%
prompt = "an old lady"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image

# %%
prompt = "an iranian old lady with black hair, brown scarf, rock background"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image

# %%
upscaled_interpolation = low_res_img.resize((800, 512))
upscaled_interpolation

# %%
url = "https://cdn.pixabay.com/photo/2017/12/28/07/44/zebra-3044577_1280.jpg"
shape = (450, 128)
low_res_img = get_low_res_img(url, shape)

low_res_img

# %%
prompt = "zebras drinking water"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image

# %%
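# match the model's 4x output size: (450 * 4, 128 * 4) = (1800, 512)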
upscaled_interpolation = low_res_img.resize((1800, 512))
upscaled_interpolation

# %% [markdown]
# # Custom Implementation

# %%
from tqdm import tqdm
from torch import autocast

# %%
class CustomSDUpscalingPipeline:
    """custom implementation of the Stable Diffusion Upscaling Pipeline"""

    def __init__(self,
                 vae,
                 tokenizer,
                 text_encoder,
                 unet,
                 low_res_scheduler,
                 scheduler,
                 image_processor):

        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.low_res_scheduler = low_res_scheduler
        self.scheduler = scheduler
        self.image_processor = image_processor
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'



    def get_text_embeds(self, text):
        """returns embeddings for the given `text`"""

        # tokenize the text
        text_input = self.tokenizer(text,
                                    padding='max_length',
                                    max_length=self.tokenizer.model_max_length,
                                    truncation=True,
                                    return_tensors='pt')
        # embed the text
        with torch.no_grad():
            text_embeds = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return text_embeds



    def get_prompt_embeds(self, prompt):
        """returns prompt embeddings based on classifier free guidance"""

        if isinstance(prompt, str):
            prompt = [prompt]
        # get conditional prompt embeddings
        cond_embeds = self.get_text_embeds(prompt)
        # get unconditional prompt embeddings
        uncond_embeds = self.get_text_embeds([''] * len(prompt))
        # concatenate the above 2 embeds for classifier-free guidance
        prompt_embeds = torch.cat([uncond_embeds, cond_embeds])
        return prompt_embeds


    def transform_image(self, image):
        """convert image from pytorch tensor to PIL format"""

        image = self.image_processor.postprocess(image, output_type='pil')
        return image



    def get_initial_latents(self, height, width, num_channels_latents, batch_size):
        """returns noise latent tensor of relevant shape scaled by the scheduler"""

        image_latents = torch.randn((batch_size, num_channels_latents, height, width)).to(self.device)
        # scale the initial noise by the standard deviation required by the scheduler
        image_latents = image_latents * self.scheduler.init_noise_sigma
        return image_latents



    def denoise_latents(self,
                        prompt_embeds,
                        image,
                        timesteps,
                        latents,
                        noise_level,
                        guidance_scale):
        """denoises latents from noisy latent to a meaningful latents"""

        # use autocast for automatic mixed precision (AMP) inference
        with autocast('cuda'):
            for t in tqdm(timesteps):
                # duplicate image latents to do classifier free guidance
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat([latent_model_input, image], dim=1)

                # predict noise residuals
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        class_labels=noise_level
                    )['sample']

                # separate predictions for unconditional and conditional outputs
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                # perform guidance
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # remove the noise from the current sample i.e. go from x_t to x_{t-1}
                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

        return latents



    def __call__(self,
                 prompt,
                 image,
                 num_inference_steps=20,
                 guidance_scale=9.0,
                 noise_level=20):
        """generates new image based on the `prompt` and the `image`"""

        # encode input prompt
        prompt_embeds = self.get_prompt_embeds(prompt)

        # preprocess image
        image = self.image_processor.preprocess(image).to(self.device)

        # prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # add noise to image
        noise_level = torch.tensor([noise_level], device=self.device)
        noise = torch.randn(image.shape, device=self.device)
        image = self.low_res_scheduler.add_noise(image, noise, noise_level)

        # duplicate image for classifier free guidance
        image = torch.cat([image] * 2)
        noise_level = torch.cat([noise_level] * image.shape[0])

        # prepare the initial image in the latent space (noise on which we will do reverse diffusion)
        num_channels_latents = self.vae.config.latent_channels
        batch_size = prompt_embeds.shape[0] // 2
        height, width = image.shape[2:]
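        # (the upscaler's latents share the low-res image's spatial size; the VAE decoder performs the 4x upsampling)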
        latents = self.get_initial_latents(height, width, num_channels_latents, batch_size)

        # denoise latents
        latents = self.denoise_latents(prompt_embeds,
                                       image,
                                       timesteps,
                                       latents,
                                       noise_level,
                                       guidance_scale)

        # decode latents to get the image into pixel space
        latents = latents.to(torch.float16)
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # convert to PIL Image format
        image = self.transform_image(image.detach()) # detach to remove any computed gradients

        return image

# %%
# get all the components from the SD Upscaler pipeline
vae = pipeline.vae
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
unet = pipeline.unet
low_res_scheduler = pipeline.low_res_scheduler
scheduler = pipeline.scheduler
image_processor = pipeline.image_processor

custom_pipe = CustomSDUpscalingPipeline(vae, tokenizer, text_encoder, unet, low_res_scheduler, scheduler, image_processor)
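
# %%
# optional: the custom pipeline draws its starting latents and noise with torch.randn,
# so seeding PyTorch's RNG makes the results reproducible (the seed value is arbitrary)
torch.manual_seed(42)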

# %%
url = "https://cdn.pixabay.com/photo/2017/02/07/16/47/kingfisher-2046453_640.jpg"
shape = (200, 128)
low_res_img = get_low_res_img(url, shape)

low_res_img

# %%
prompt = "an aesthetic kingfisher"
upscaled_image = custom_pipe(prompt=prompt, image=low_res_img)[0]
upscaled_image
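
# %%
# the defaults above can be overridden; e.g. more denoising steps and a lower
# guidance scale (these values are just examples, tune them to taste)
upscaled_image = custom_pipe(prompt=prompt, image=low_res_img,
                             num_inference_steps=30, guidance_scale=7.5)[0]
upscaled_image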

# %%
url = "https://cdn.pixabay.com/photo/2018/07/31/22/08/lion-3576045_1280.jpg"
shape = (200, 128)
low_res_img = get_low_res_img(url, shape)

low_res_img

# %%
prompt = "a professional photograph of a lion's face"
upscaled_image = custom_pipe(prompt=prompt, image=low_res_img)[0]
upscaled_image

# %%
upscaled_interpolation = low_res_img.resize((800, 512))
upscaled_interpolation
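
# %%
# optional: put the diffusion result and the plain interpolation side by side for
# inspection; both are 800x512 here, and this uses only PIL's paste API
comparison = Image.new("RGB", (upscaled_image.width + upscaled_interpolation.width, upscaled_image.height))
comparison.paste(upscaled_image, (0, 0))
comparison.paste(upscaled_interpolation, (upscaled_image.width, 0))
comparison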
