Understanding RAG (Retrieval-Augmented Generation) with a practical (simple) example

Michaël Scherding
5 min read · May 15, 2024


Introduction

Retrieval-Augmented Generation (RAG) is a powerful technique that combines information retrieval with text generation. It is particularly useful when you need to generate responses grounded in specific context extracted from a set of documents.

In this article, we will explore the logic of RAG through a simple and basic example using Python, OpenAI, and a few other libraries. Follow this step-by-step guide to understand how RAG works and how you can implement it in your own projects.

Setting up the environment

To begin, we need to set up our working environment by installing the necessary libraries. Make sure you have access to Google Colab or a similar Python environment.

Installing libraries

!pip install openai pymupdf faiss-cpu scikit-learn
  • OpenAI: A library to interact with OpenAI’s language models for tasks like text generation and conversation.
  • PyMuPDF: A library for accessing and manipulating PDF documents, including text extraction.
  • FAISS: A library for efficient similarity search and clustering of dense vectors.
  • Scikit-learn: A machine learning library providing tools for data analysis, modeling, and preprocessing.

Extracting text from PDF files

The first step is extracting text from the PDF files that will serve as our context source. For this, we will use the PyMuPDF library.

Extracting text from PDFs

from google.colab import files
import fitz  # PyMuPDF

# Upload PDF files to Google Colab
uploaded = files.upload()

# Function to extract text from a PDF
def extract_text_from_pdf(pdf_path):
    doc = fitz.open(pdf_path)
    text = ""
    for page_num in range(len(doc)):
        page = doc.load_page(page_num)
        text += page.get_text()
    return text

# Extract text from all uploaded PDF files
pdf_texts = {}
for pdf_file in uploaded.keys():
    if pdf_file.endswith(".pdf"):
        pdf_texts[pdf_file] = extract_text_from_pdf(pdf_file)

# Display extracted text from each PDF file
for pdf_file, text in pdf_texts.items():
    print(f"--- {pdf_file} ---")
    print(text[:500])  # Display the first 500 characters of each document
    print("\n")

In this step, we use PyMuPDF to open and read the contents of the uploaded PDF files. We define a function extract_text_from_pdf that extracts text from each page of the PDF and concatenates it into a single string. Then, we iterate over all uploaded PDF files and store their textual content in a dictionary pdf_texts.
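
As a side note, if you are not working in Google Colab, the same extraction can be run over a local folder instead of an upload. Here is a minimal sketch, assuming your PDFs sit in a hypothetical ./pdfs directory:

from pathlib import Path
import fitz  # PyMuPDF

# Assumption: the PDFs live in a local ./pdfs folder instead of a Colab upload
pdf_texts = {}
for pdf_path in Path("pdfs").glob("*.pdf"):
    doc = fitz.open(str(pdf_path))
    # Concatenate the text of every page, as in the Colab version above
    pdf_texts[pdf_path.name] = "".join(page.get_text() for page in doc)
    doc.close()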

Text vectorization and creating the FAISS index

To perform efficient searches, we need to convert our text data into numerical vectors. We will use the TF-IDF (Term Frequency-Inverse Document Frequency) vectorizer from Scikit-learn for this task. After vectorizing the text, we will use FAISS to create an index to search through the vectors quickly.

Vectorizing text and creating the FAISS index

from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import faiss

# Convert text documents to TF-IDF vectors
documents = list(pdf_texts.values())
vectorizer = TfidfVectorizer()
doc_vectors = vectorizer.fit_transform(documents).toarray().astype(np.float32)  # FAISS expects float32

# Create a FAISS index
dimension = doc_vectors.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(doc_vectors)

We use the TfidfVectorizer from Scikit-learn to transform the text documents into TF-IDF vectors. This converts each document into a numerical vector representing the importance of each term in the document relative to the corpus; the fit_transform method fits the vectorizer to the documents and transforms them in one step.

FAISS (Facebook AI Similarity Search) is a library for efficient similarity search and clustering of dense vectors. We create an index with IndexFlatL2, which builds a flat (non-hierarchical) index based on L2 (Euclidean) distance, and add our document vectors to it with the add method. Note that FAISS expects float32 input, which is why the vectors are converted with astype before being added.
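
Before moving on, it can help to sanity-check what the vectorizer and the index actually contain. A quick, optional inspection (the exact numbers and terms depend on your PDFs):

# Inspect the TF-IDF representation (shapes and terms depend on your documents)
print(doc_vectors.shape)                        # (number_of_documents, vocabulary_size)
print(vectorizer.get_feature_names_out()[:10])  # first few terms in the vocabulary
print(index.ntotal)                             # number of vectors stored in the FAISS index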

Searching the index

With our text data vectorized and indexed, we can now perform searches. We will define a function to search the index for the most relevant documents based on a query.

Searching the index

def search_documents(query, top_k=5):
    # Vectorize the query with the same TF-IDF vectorizer (FAISS expects float32)
    query_vector = vectorizer.transform([query]).toarray().astype(np.float32)
    # Never ask for more results than there are documents in the index
    distances, indices = index.search(query_vector, min(top_k, len(documents)))
    # Pair each matching document with its distance to the query
    results = [(documents[i], distances[0][rank]) for rank, i in enumerate(indices[0])]
    return results

# Example query
query = "impact of climate change"
search_results = search_documents(query)
for result in search_results:
    print(result)

The search_documents function takes a query and the number of top results to return (top_k). It transforms the query into a TF-IDF vector using the same vectorizer, then uses the search method of the FAISS index to find the vectors (documents) closest to the query vector, and returns the top_k documents along with their distances. We demonstrate the function with an example query about the impact of climate change; the results are printed, showing the most relevant documents and their distances.
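
Because documents only holds raw text, printing whole documents can be hard to read. Here is a small optional sketch that maps each hit back to its source file, relying on the fact that pdf_texts preserves insertion order (so list(pdf_texts.keys()) lines up with documents):

# Map each hit back to the PDF it came from
filenames = list(pdf_texts.keys())
query_vector = vectorizer.transform([query]).toarray().astype(np.float32)
distances, indices = index.search(query_vector, min(5, len(documents)))
for rank, i in enumerate(indices[0]):
    print(f"{filenames[i]} (distance: {distances[0][rank]:.4f})")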

Using OpenAI API for Retrieval-Augmented Generation (RAG)

In this step, we’ll combine the context retrieved from our documents with GPT-4 to generate responses. The context will provide the necessary information to the model to produce more accurate and relevant answers.

Generating augmented responses with OpenAI API

import openai

# Set your OpenAI API key
openai.api_key = 'your_openai_api_key'  # Replace with your own API key

def generate_augmented_response(query):
    # Retrieve the most relevant documents and merge them into a single context
    search_results = search_documents(query)
    context = "\n".join([result[0] for result in search_results])

    prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    # Note: openai.ChatCompletion is the legacy interface (openai < 1.0)
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=200
    )
    return response.choices[0].message['content'].strip()

# Example usage
query = "What is the impact of climate change on marine life?"
response = generate_augmented_response(query)
print(response)

Replace 'your_openai_api_key' with your actual OpenAI API key. The generate_augmented_response function first retrieves the most relevant documents using search_documents and joins them into a single context. The prompt is formatted to provide that context followed by the question. We then call openai.ChatCompletion.create to generate a response from GPT-4: the messages parameter describes the conversation, starting with a system message that sets the assistant's role, followed by the user prompt, and max_tokens limits the length of the generated answer. The function returns the generated response with leading and trailing whitespace stripped. Note that openai.ChatCompletion.create is the legacy interface from openai versions before 1.0; a sketch of the equivalent call with the newer client follows below.
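
If you are running openai >= 1.0, a minimal sketch of the same call with the current client (same model, prompt, and parameters as above) could look like this:

from openai import OpenAI

# openai >= 1.0 client; the key can also be read from the OPENAI_API_KEY environment variable
client = OpenAI(api_key="your_openai_api_key")

def generate_augmented_response(query):
    search_results = search_documents(query)
    context = "\n".join(result[0] for result in search_results)
    prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"

    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=200,
    )
    return response.choices[0].message.content.strip()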

Conclusion

In this article, we have walked through the process of setting up a simple Retrieval-Augmented Generation (RAG) system using Python, OpenAI, and several other libraries. Here are the key steps we covered:

  1. Setting up the environment: Installing the necessary libraries.
  2. Extracting text from PDFs: Using PyMuPDF to extract text from PDF files.
  3. Vectorizing text and creating the FAISS index: Converting text to TF-IDF vectors and creating an index for efficient searching.
  4. Searching the index: Implementing a function to search the index for relevant documents.
  5. Generating augmented responses with OpenAI API: Using GPT-4 to generate responses based on the retrieved context.

By following these steps, you can create a basic RAG system to enhance your text generation tasks with relevant context from your documents. This approach can be expanded and scaled for more complex and larger datasets, making it a versatile tool for various applications.

See ya 🤟
