Chatting with docs 2.0

The rise of Vision Language Models

Anukriti Ranjan
9 min read · Nov 11, 2024

In late 2022, Retrieval-Augmented Generation (RAG) emerged as a popular approach to holding conversations with documents. It let users ask multi-turn questions about PDFs and other documents, facilitating a deeper understanding of the text. The foundational strategy involved breaking documents down into smaller chunks, embedding these chunks, and storing the embeddings for later retrieval. When a user query arrived, the system would embed it in the same way, retrieve the most relevant chunks using a metric such as cosine similarity, and send these chunks as context to a large language model (LLM) to generate a response.
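The chunk, embed and retrieve loop described above can be sketched in a few lines. This is only a minimal illustration, not a production pipeline: `embed` is a hypothetical bag-of-words stand-in for a real embedding model, and `chunk_text`, `retrieve` and the sample text are made up for the example.

```python
import math
import re
from collections import Counter


def chunk_text(text, max_words=12):
    """Split text into consecutive chunks of at most `max_words` words."""
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]


def embed(text):
    """Hypothetical stand-in for an embedding model: a bag-of-words count vector."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine_similarity(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query, chunks, top_k=1):
    """Return the `top_k` chunks most similar to the query."""
    query_vec = embed(query)
    ranked = sorted(chunks, key=lambda c: cosine_similarity(query_vec, embed(c)), reverse=True)
    return ranked[:top_k]


document = (
    "The device operates from a 3.3 V supply and draws 12 mA in active mode. "
    "The package is a 32 pin QFN with an exposed thermal pad on the underside."
)
chunks = chunk_text(document)
context = retrieve("What supply voltage does the device need?", chunks)
# `context` would be placed in the LLM prompt alongside the user query.
```

In a real system the count vectors would be replaced by dense embeddings from a model, and the sorted list by a vector index.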

Over time, the workflow became more “agentic,” meaning that LLMs were increasingly leveraged to refine the retrieval process itself. Techniques such as structured information extraction, metadata filtering, and query reformulation were employed to enhance relevance. For example, the model could generate an answer without context and use that answer (rather than the initial query) for retrieval or even perform “step-by-step” query planning to guide the retrieval process towards more pertinent documents.
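The "answer first, then retrieve" idea can be sketched as follows. It is an illustration under stated assumptions: `fake_llm` is a hypothetical stub standing in for a real model call, and `overlap_score` is a deliberately crude word-overlap retriever rather than an embedding search.

```python
import re


def tokens(text):
    """Lowercased word set of a text."""
    return set(re.findall(r"[a-z]+", text.lower()))


def overlap_score(a, b):
    """Crude relevance score: number of shared words (stand-in for embedding similarity)."""
    return len(tokens(a) & tokens(b))


def retrieve(search_text, documents):
    """Return the document that best matches `search_text`."""
    return max(documents, key=lambda d: overlap_score(search_text, d))


def retrieve_via_draft_answer(query, documents, generate):
    """Ask the model for a context-free draft answer, then retrieve with that draft."""
    draft = generate(query)
    return retrieve(draft, documents)


def fake_llm(query):
    """Hypothetical stub for an LLM call; returns a plausible but unverified answer."""
    return "Revenue rose because of strong demand."


documents = [
    "The capacitor value is chosen to damp supply ripple.",
    "Quarterly revenue rose on strong cloud demand.",
]
query = "Why did sales go up?"

direct_hit = retrieve(query, documents)
reformulated_hit = retrieve_via_draft_answer(query, documents, fake_llm)
```

Here the query shares no words with either document, so direct retrieval has nothing to go on, while the draft answer happens to use the vocabulary of the relevant document. The draft itself may be wrong; it is used only as a search key.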

This added complexity brings engineering overhead, which would be justified only if the resulting accuracy were high. A further problem with text chunking and embedding is that extracting structured text from documents is far from straightforward. Consider a chip datasheet, or a document such as the SDG report by NITI Aayog, where pages can look like this:

Source: SDG INDIA Index & Dashboard 2020–21

Enter Vision Language Models

Vision Language Models (VLMs) provide a more seamless solution by combining image and text inputs to answer queries based on visual context. With ColPali, an adapted VLM, we can now embed entire document page images and compare them to query embeddings for a more grounded approach. The retrieved images, together with the query, are then processed by a larger VLM to derive answers from the document’s actual content. Hugging Face offers an excellent cookbook for implementing this approach.

An approximate workflow with ColPali is shown below.

Image by the author

PaliGemma and ColPali: Under the Hood

The technical backbone of ColPali is derived from PaliGemma, a smaller 3-billion-parameter model from DeepMind. This model is known for its contrastive vision encoder and transformer decoder (Gemma 2B LM) architecture, designed for high-performance transfer learning.

The components of PaliGemma are depicted in the figure from the paper: PaliGemma: A versatile 3B VLM for transfer

Source: https://arxiv.org/abs/2407.07726

Let’s look at these two components, the contrastive vision encoder and the Gemma 2B transformer decoder, in a bit more detail.

Vision encoder

A fundamental constituent of a VLM is a vision encoder. Vision encoders in models like CLIP and SigLIP are neural networks that process an input image and output a representation (embedding) that captures high-level semantic information about the image. These embeddings are not used directly for detecting edges or specific objects; they are optimized to capture features that let the model align visual content with textual descriptions. Typically, these vision encoders are Vision Transformers (ViTs), which divide an image into patches (e.g., 14x14 pixels) and treat each patch as a “token,” similar to how words are treated in NLP. Each patch is embedded into a vector and processed with self-attention to capture relationships between different parts of the image.
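The patch arithmetic is easy to check by hand. The sketch below assumes square images, 14x14 patches, and patches numbered row by row; `patch_grid` and `patch_box` are illustrative helper names, not part of any library.

```python
def patch_grid(image_size, patch_size=14):
    """Return (patches per side, total patches) for a square image."""
    if image_size % patch_size:
        raise ValueError("image_size must be a multiple of patch_size")
    per_side = image_size // patch_size
    return per_side, per_side * per_side


def patch_box(index, image_size, patch_size=14):
    """Pixel box (top, left, bottom, right) of a patch, counting patches row by row."""
    per_side, total = patch_grid(image_size, patch_size)
    if not 0 <= index < total:
        raise IndexError("patch index out of range")
    row, col = divmod(index, per_side)
    return (row * patch_size, col * patch_size,
            (row + 1) * patch_size, (col + 1) * patch_size)


for size in (224, 448, 896):
    print(size, patch_grid(size))
```

A 448x448 image therefore becomes a 32x32 grid, that is, 1024 patch tokens. The same index-to-box mapping is what a patch-level heatmap needs.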

A vision encoder is trained roughly as follows.

Image by the author
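As a complement to the figure, here is a rough sketch of one common contrastive recipe, the CLIP-style symmetric loss, in which the i-th image in a batch should match the i-th caption and no other. It is an illustration only: SigLIP replaces the softmax with a pairwise sigmoid loss (not shown here), the temperature of 0.07 is an illustrative value, and the two-dimensional embeddings are made up.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def l2_normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def cross_entropy(logits, target):
    """Softmax cross-entropy for one row of logits and one target index."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum - logits[target]


def contrastive_loss(image_embs, text_embs, temperature=0.07):
    """CLIP-style symmetric loss: the i-th image should match the i-th text."""
    images = [l2_normalize(v) for v in image_embs]
    texts = [l2_normalize(v) for v in text_embs]
    n = len(images)
    sims = [[dot(i, t) / temperature for t in texts] for i in images]
    image_to_text = sum(cross_entropy(sims[k], k) for k in range(n)) / n
    text_to_image = sum(
        cross_entropy([sims[r][k] for r in range(n)], k) for k in range(n)
    ) / n
    return (image_to_text + text_to_image) / 2


image_embs = [[1.0, 0.1], [0.1, 1.0]]
aligned_texts = [[0.9, 0.2], [0.2, 0.9]]
swapped_texts = [[0.2, 0.9], [0.9, 0.2]]

loss_aligned = contrastive_loss(image_embs, aligned_texts)
loss_swapped = contrastive_loss(image_embs, swapped_texts)
```

The loss is small when each image embedding sits closest to its own caption and large when the pairs are swapped, which is the signal that pushes the encoder towards text-aligned features.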

PaliGemma is trained on next-token prediction. How, then, does ColPali use it to extract image embeddings?

PaliGemma has the following components.

Image by the author

ColPali modifies the above architecture as follows:

Image by the author

A major addition is the custom text projector, which downscales the 2048-dimensional embeddings to 128 dimensions. LoRA layers are also introduced so that the model can be fine-tuned in a parameter-efficient manner. The sequence length of image patches has been increased to 1024. The original PaliGemma was trained progressively on 224x224, 448x448 and 896x896 pixel images; its Stage 1 checkpoint, at a resolution of 224 px, serves as a useful base model for many tasks.
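To see why LoRA counts as parameter-efficient: instead of updating a full weight matrix of shape d_out x d_in, it trains two low-rank factors of shapes d_out x r and r x d_in and leaves the original weights frozen. The numbers below (a single 2048 x 2048 matrix, rank 32) are illustrative assumptions, not values taken from this article.

```python
def lora_trainable_params(d_out, d_in, rank):
    """Parameters in the two low-rank factors B (d_out x rank) and A (rank x d_in)."""
    return rank * (d_out + d_in)


def full_params(d_out, d_in):
    """Parameters in the full weight matrix that LoRA leaves frozen."""
    return d_out * d_in


# Illustrative assumption: one 2048 x 2048 projection adapted with rank 32.
d_out, d_in, rank = 2048, 2048, 32
full = full_params(d_out, d_in)
lora = lora_trainable_params(d_out, d_in, rank)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```

Under these assumptions the adapter trains roughly 3% of the parameters of the matrix it modifies.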

Typical inference code for ColPali is shown below. For this example, I took the following page from a McKinsey report on the state of AI.

Source: The state of AI in early 2024: Gen AI adoption spikes and starts to generate value -McKinsey & Company

The code has been taken from the Hugging Face model page for vidore/colpali-v1.2.

!pip install -U -q byaldi pdf2image qwen-vl-utils transformers
!sudo apt-get install -y poppler-utils

import os
from pdf2image import convert_from_path


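def convert_pdfs_to_images(pdf_folder):
    """Convert every PDF in `pdf_folder` to a list of page images, keyed by document index."""
    # Sort the file names so that document ids are stable across runs
    pdf_files = sorted(f for f in os.listdir(pdf_folder) if f.endswith(".pdf"))
    all_images = {}

    for doc_id, pdf_file in enumerate(pdf_files):
        pdf_path = os.path.join(pdf_folder, pdf_file)
        images = convert_from_path(pdf_path)  # one PIL image per page
        all_images[doc_id] = images

    return all_images


all_images = convert_pdfs_to_images("/content/pdfs/")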

from typing import cast

import torch
from PIL import Image

from colpali_engine.models import ColPali, ColPaliProcessor

model_name = "vidore/colpali-v1.2"

model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
).eval()

processor = ColPaliProcessor.from_pretrained(model_name)

# Your inputs
images = [
    all_images[0][5],  # sixth page of the first PDF
]
queries = [
    "which two use cases of AI within marketing and sales are reported by 15 percent or more of respondents?",
]

# Process the inputs
batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**batch_images)
    querry_embeddings = model(**batch_queries)

If you look into batch_images and batch_queries, they look as follows.

You can re-construct the image from the pixel values with the help of the following code.

import torch
import matplotlib.pyplot as plt
import numpy as np

# Example tensor with shape [1, 3, 448, 448]
my_tensor = batch_images.pixel_values

# Remove the batch dimension (convert [1, 3, 448, 448] to [3, 448, 448])
my_tensor = my_tensor.squeeze(0)

# Permute the tensor to convert from [3, 448, 448] (C, H, W) to [448, 448, 3] (H, W, C)
my_tensor = my_tensor.permute(1, 2, 0)

# Convert the tensor to a NumPy array (cast to float32 first, since NumPy has no bfloat16)
# and min-max normalize the values to [0, 1] for display purposes
image = my_tensor.float().cpu().numpy()
image = (image - image.min()) / (image.max() - image.min())

# Display the image
plt.imshow(image)
plt.axis('off') # Turn off the axis to focus on the image
plt.show()

257152 is the token id for image tokens in PaliGemma, but you may have noticed some other ids at the end of the input ids in batch_images as well. This is because ColPali uses the visual_prompt_prefix “<image><bos>Describe the image.”, whose text tokens follow the image tokens.

So how do we get query and image embeddings separately from the ColPali model? Take a look at the forward method of the model in PyTorch.

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Delete output_hidden_states from kwargs
        kwargs.pop("output_hidden_states", None)
        if "pixel_values" in kwargs:
            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)

        outputs = self.model(*args, output_hidden_states=True, **kwargs)  # (batch_size, sequence_length, hidden_size)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return proj

The last hidden state (of shape 1030 x 2048 in the case of image embeddings), projected to 128 dimensions by the custom text projector layer, is what essentially serves as the embedding for each patch/token. But why 1030 and not 1024, given that 448*448/(14*14) = 1024, where 14*14 is the patch size? Because the process_images method of the processor adds the visual_prompt_prefix, and its text tokens account for the extra positions.
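The last two steps of the forward method above, L2 normalization followed by masking, can be illustrated on toy vectors. This is a sketch in plain Python rather than PyTorch, and the vectors and mask are made up for the example.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def normalize_and_mask(token_vectors, attention_mask):
    """Unit-normalize each token vector, then zero out padding positions."""
    return [
        [x * keep for x in l2_normalize(vec)]
        for vec, keep in zip(token_vectors, attention_mask)
    ]


# Three token vectors; the last one is padding (mask value 0).
token_vectors = [[3.0, 4.0], [0.0, 2.0], [5.0, 5.0]]
attention_mask = [1, 1, 0]

embeddings = normalize_and_mask(token_vectors, attention_mask)
```

Because every kept vector has unit length, a plain dot product between two embeddings equals their cosine similarity, and padded positions become zero vectors that contribute nothing to any similarity score.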

To compute the retrieval score, ColPali uses a method called MaxSim: for each query token, pick the image patch with the maximum similarity, then sum these maxima over all query tokens to get the final score. A nice side effect is that ColPali offers some interpretability. Since it retains the individual patch and query-token embeddings, you can visualize which image patch has the highest similarity with which query token. For my experiment, I used the following code.
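Before the visualization, here is a minimal sketch of the MaxSim score itself, written in plain Python on toy two-dimensional embeddings. The function names and numbers are illustrative; the real embeddings are 128-dimensional tensors and would be scored with a batched matrix product.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_embs, page_embs):
    """Sum, over query tokens, of the best dot product against any page patch."""
    return sum(max(dot(q, p) for p in page_embs) for q in query_embs)


def best_patch_per_token(query_embs, page_embs):
    """Index of the most similar patch for each query token (useful for heatmaps)."""
    return [
        max(range(len(page_embs)), key=lambda i: dot(q, page_embs[i]))
        for q in query_embs
    ]


# Toy unit-length embeddings: 2 query tokens, 3 patches per page.
query = [[1.0, 0.0], [0.0, 1.0]]
page_a = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
page_b = [[0.6, 0.8], [0.8, 0.6], [0.0, 0.0]]

score_a = maxsim_score(query, page_a)
score_b = maxsim_score(query, page_b)
```

Page A wins because each query token finds an exact match among its patches, whereas page B only offers partial matches. The per-token argmax returned by `best_patch_per_token` is the information a heatmap builds on.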

import torch
import numpy as np
import matplotlib.pyplot as plt

# Remove the batch dimension for simplicity
image_embeddings = image_embeddings.squeeze(0) # [1030, 128]
querry_embeddings = querry_embeddings.squeeze(0) # [36, 128]

# Compute ColBERT-style similarity scores
scores = torch.matmul(querry_embeddings, image_embeddings.T) # [36, 1030]

# Define parameters
patch_size = 14 # Each patch is 14x14 pixels
grid_size = 32 # Number of patches along one dimension (448 / 14)

# Get the top `n` patches for each query token
n = 5 # Number of top patches to visualize per query token

# Iterate over each query token and visualize the top `n` patches
for token_idx in range(querry_embeddings.shape[0]):  # Iterate over query tokens
    # Get the top `n` patch indices for the current query token
    top_patches = torch.topk(scores[token_idx, :1024], n).indices.tolist()  # Limit to the first 1024 indices
    query_token = processor.decode(batch_queries.input_ids[0][token_idx])
    # Create a mask for visualization
    mask = np.zeros((448, 448))

    # Highlight the identified patches in the mask
    for idx in top_patches:
        if idx < 1024:  # Ensure we only use valid image patch indices
            x = idx % grid_size  # Horizontal coordinate (column)
            y = idx // grid_size  # Vertical coordinate (row)
            mask[y * patch_size:(y + 1) * patch_size, x * patch_size:(x + 1) * patch_size] = 1

    # Load the original image (assuming `image_tensor` of shape [3, 448, 448])
    image_array = batch_images.pixel_values.squeeze(0).permute(1, 2, 0).cpu().numpy()  # Convert to [H, W, C] format
    image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min())  # Normalize for display

    # Plot the image with the heatmap overlay for the current token
    plt.figure(figsize=(10, 10))
    plt.imshow(image_array)
    plt.imshow(mask, cmap='hot', alpha=0.5)  # Overlay the mask with transparency
    plt.axis('off')
    plt.title(f'Top {n} Patches for Query Token {token_idx}: {query_token}')
    plt.show()

For example:

For real use, prefer the interpretability code from the official repo: https://github.com/illuin-tech/colpali/tree/main/colpali_engine/interpretability

ColPali handles retrieval only. You still have to send the retrieved images to another Vision Language Model to generate the answer. The Hugging Face cookbook uses Qwen2 VL, but it can be any such model, for example Llama Vision 11B or Gemini Flash.
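The hand-off from retrieval to generation can be sketched as below. The scores, the page keys and the `build_vlm_request` payload are hypothetical; the actual request format depends entirely on the answering model you choose.

```python
import heapq


def top_k_pages(page_scores, k=3):
    """Return the ids of the k best-scoring pages, best first."""
    return heapq.nlargest(k, page_scores, key=page_scores.get)


def build_vlm_request(query, page_ids):
    """Bundle the query with the retrieved page ids for the answering model."""
    return {"query": query, "pages": list(page_ids)}


# Hypothetical MaxSim scores keyed by (document id, page number).
page_scores = {(0, 4): 11.2, (0, 5): 17.9, (0, 6): 12.4, (1, 0): 9.8}
selected = top_k_pages(page_scores, k=2)
request = build_vlm_request("Which use cases are most reported?", selected)
```

Only the selected page images, together with the query, would then be passed to the larger VLM.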

Why Retrieval Still Matters Despite Expanding Model Context

One may question why retrieval is still necessary when model contexts are expanding (with capabilities of up to 128k tokens). Retrieval remains valuable because it narrows the scope, reducing unnecessary computation and homing in on the relevant document segments. Even as models grow in size and capability, retrieval keeps responses efficient and focused, which matters for both user experience and system performance.
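A back-of-the-envelope calculation makes the point concrete. Every number here is an assumption chosen for illustration: a 128k-token window, about 1,030 tokens per page image (the ColPali sequence length seen earlier; the answering VLM may use a different count), 2,000 tokens reserved for the question and the answer, and a hypothetical 300-page report.

```python
def pages_that_fit(context_tokens, tokens_per_page, reserved_tokens=0):
    """How many page images fit in a context window after reserving some tokens."""
    return max(0, (context_tokens - reserved_tokens) // tokens_per_page)


def image_tokens_needed(num_pages, tokens_per_page):
    """Image tokens consumed if `num_pages` pages are sent to the model."""
    return num_pages * tokens_per_page


# Illustrative assumptions, see the paragraph above.
CONTEXT = 128_000
PER_PAGE = 1_030
capacity = pages_that_fit(CONTEXT, PER_PAGE, reserved_tokens=2_000)
whole_report = image_tokens_needed(300, PER_PAGE)
top_three = image_tokens_needed(3, PER_PAGE)
```

Under these assumptions the whole report would not fit in the window at all, while three retrieved pages cost only a few thousand tokens per question.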

Conclusion

In this piece, we explored the evolution of document-interaction technologies, from RAG to advanced Vision Language Models like ColPali. While the journey from simple text embedding to multimodal VLMs has introduced new complexities, the payoff is evident: improved accuracy, interpretability, and alignment with user intent. With ColPali and similar technologies, we’re moving towards a future where querying a document becomes a dynamic, seamless experience across text and visuals alike. Even Anthropic has recently rolled out visual PDF support. As Alex Albert mentions, “Up until today, when you attached a PDF in claude dot ai, we would use a text extraction service to grab the text and send that to Claude in the prompt.”

Source: x.com

It does look like VLMs will multiply the use cases of AI, not just for chatting with your docs but also for other tasks that need semantic image understanding.
