RAG app with Postgres and pgvector

RAG (Retrieval Augmented Generation) apps are really popular right now, and we decided to develop one exclusively using Postgres and pgvector. In this blog post, I will talk a little bit about what RAG is, break down the fundamental elements of building a RAG app from scratch, and offer some helpful tips along the way.

All the code is available on GitHub, and I also recorded my screen running the app to show you how it works.

What is Retrieval Augmented Generation, or RAG?

Retrieval Augmented Generation (RAG) has become the preferred method for making Large Language Model (LLM) applications smarter. We achieve this by adding our own data that LLMs don't usually know about.

RAG combines traditional language generation models with retrieval-based methods, like using our own data from databases such as Postgres, to enhance the quality and relevance of generated text. We gather relevant data for a question or task and include it as context for the LLM. By incorporating the retrieval step and the additional context, RAG improves the coherence, accuracy and specificity of generated text, making it particularly useful for tasks like question answering, summarization and dialogue creation.

All of these steps give the LLM a deeper understanding of the topic, effectively making it smarter. RAG also helps reduce the inaccuracies known as “hallucinations” in generated responses.
 

Motivation behind the app

In the last few months, I've been presenting about pgvector at various conferences and meetups, and connecting with customers eager to explore the potential of Postgres and pgvector. Many of them are interested in incorporating their own data into LLMs to develop chatbot-style applications that draw on their domain-specific knowledge.

I've observed some consistent trends in their requests, driving the motivation behind the app:

  • Users want the ability to inject different types of data sources, from Jira and Github issues to Confluence documentation, blog posts, internal training materials and PDF documents into their RAG apps.
  • Due to data privacy concerns, they lean towards local LLM deployments to avoid sending data to external APIs like OpenAI API. 
  • They want to store and query vector data directly in Postgres using pgvector, simplifying data management processes. This is particularly understandable since they are already familiar with Postgres and use it in production.
  • They want to control and regulate the vector search based on the role or privilege of the user asking the question. Different projects, stakeholders and departments need access to their own information while restricting access for others.

Limitations with RAG apps

While the Retrieval-Augmented Generation (RAG) architecture offers many benefits, it also comes with several challenges, especially when running Large Language Models (LLMs) locally:

  • Running LLMs on a CPU is challenging, as most models are optimized for GPUs (e.g., Llama models). It is important to note that running LLMs locally is not a requirement of RAG architecture; we’ve chosen this approach in our app to align with trends we’ve observed in customer requirements.
  • Development and testing locally can be time-consuming due to the memory, cache and CPU constraints of typical laptops.
  • The prompt a RAG app can construct is limited by the model's context window, which is the number of tokens an LLM can process as input when generating a response. This fixed token limit can restrict the model's ability to understand and respond to long or complex inputs.
  • Scaling the system to handle increased loads or larger models can be difficult, requiring careful planning and resource management.
  • The cost of the environment for AWS instances, particularly GPU-optimized instances like g5.2xlarge, can be quite high.

Process flow of RAG applications

When designing the application, I envisioned the workflow as depicted in the diagram below:

 

[Diagram: Process flow of RAG applications]

Step 1: Data Processing

  • Ingest PDFs and documents.
  • Create data chunks.
  • Encode chunks as vectors and store them in PostgreSQL using pgvector.

Step 2: Embedding Model

  • Convert text chunks into embedding vectors.
  • Prepare data for the chat model.

Step 3: User Query

  • Allow users to input questions.
  • Use queries to prompt the system.

Step 4: Retrieve Relevant Sections

  • Identify top N relevant document sections using vectors.
  • Optimize model’s token usage.

Step 5: Create Composite Prompt

  • Generate a prompt with relevant vectors, system prompts, and the user's question.
  • Include recent conversation history for context.

Step 6: Send Prompt to Chat Model

  • Forward the composite prompt to the chat model.

Step 7: Provide Answer

  • Retrieve response from the chat model.
  • Send the response back to the user.

Then I started developing and putting all these building blocks together.

Application Architecture

The application follows a standard RAG (Retrieval-Augmented Generation) workflow, as we discussed earlier. The key elements here are Postgres and pgvector for storing vectors and building the chatbot, and also the aspect of running LLMs locally. These components form the backbone of our app design.

Requirements

  • PostgreSQL (version 12 or higher, as pgvector requires 12+)
  • pgvector
  • Python 3
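
The Python dependencies can be inferred from the imports used throughout the code; as a rough, unpinned sketch (the exact versions belong in the repository's requirements file), an install might look like this:

pip install psycopg2-binary pgvector torch transformers accelerate bitsandbytes PyPDF2 python-dotenv

psycopg2-binary provides the psycopg2 module used for database access, pgvector provides register_vector, transformers/accelerate/bitsandbytes cover model loading and 4-bit quantization, PyPDF2 handles PDF parsing, and python-dotenv loads the environment variables.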

The application involves three main steps: creating the database, importing data, and initiating the chat functionality. These steps are encapsulated in `app.py`, and you can run the application using the following commands:

python app.py --help
usage: app.py [-h] {create-db,import-data,chat} ...

Application Description

options:
  -h, --help            show this help message and exit

Subcommands:
  {create-db,import-data,chat}
                        Display available subcommands
    create-db           Create a database
    import-data         Import data
    chat                Use chat feature
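
For example, once the environment variables are configured, a typical session might look like this (the PDF path below is just a placeholder):

python app.py create-db
python app.py import-data ./docs/my-document.pdf
python app.py chat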

Now, let’s take a closer look at the code and go through the implementation details. 

create_db.py

We start by creating a database using environment variables (DB_USER, DB_PASSWORD, DB_HOST, DB_PORT). Then we enable the pgvector extension that we installed as part of the requirements. Finally, we set up the embeddings table.

CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS embeddings (id serial PRIMARY KEY, doc_fragment text, embeddings vector(4096));
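
For a demo-sized data set, a sequential scan over the embeddings table is fine. If you later load a larger corpus, you would normally add an approximate index; as a hedged example, pgvector 0.5.0+ offers HNSW with a cosine operator class. Note, however, that pgvector's index types are currently limited to 2000 dimensions, so the 4096-dimensional column used here would have to be reduced before an index like this could be created:

CREATE INDEX ON embeddings USING hnsw (embeddings vector_cosine_ops);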

You can see the full code of create_db.py below:

import os
import psycopg2


def create_db(args, model, device, tokenizer):
    db_config = {
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASSWORD"),
        "host": os.getenv("DB_HOST"),
        "port": os.getenv("DB_PORT"),
    }

    conn = psycopg2.connect(**db_config)
    conn.autocommit = True  # Enable autocommit for creating the database

    cursor = conn.cursor()
    cursor.execute(
        f"SELECT 1 FROM pg_database WHERE datname = '{os.getenv('DB_NAME')}';"
    )
    database_exists = cursor.fetchone()
    cursor.close()

    if not database_exists:
        cursor = conn.cursor()
        cursor.execute(f"CREATE DATABASE {os.getenv('DB_NAME')};")
        cursor.close()
        print("Database created.")

    conn.close()

    db_config["dbname"] = os.getenv("DB_NAME")
    conn = psycopg2.connect(**db_config)
    conn.autocommit = True

    cursor = conn.cursor()
    cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    cursor.close()

    cursor = conn.cursor()
    cursor.execute(
        "CREATE TABLE IF NOT EXISTS embeddings (id serial PRIMARY KEY, doc_fragment text, embeddings vector(4096));"
    )
    cursor.close()

    print("Database setup completed.")

import_data.py

After we run the create-db command, we have to run the import-data command to import our documents. When import_data.py is executed, it invokes db.py and embedding.py. This is what happens during the import-data process:

  • Connect to DB
  • Read PDF files
    • Takes a PDF file path as input, reads the text content of each page in the PDF file, splits it into lines, and returns the lines as a list
  • Generate embeddings
    • Takes input text, tokenizes it, passes it through the LLM, retrieves the hidden states from the model's output, calculates the mean embedding, and returns both the original text and its corresponding embedding vector
  • Store embeddings in the database
    • Store the document fragments and their embeddings in the database

You can see import_data.py below:

import numpy as np

from db import get_connection
from embedding import generate_embeddings, read_pdf_file


def import_data(args, model, device, tokenizer):
    data = read_pdf_file(args.data_source)

    embeddings = [
        generate_embeddings(tokenizer=tokenizer, model=model, device=device, text=line)
        for line in data
    ]

    conn = get_connection()
    cursor = conn.cursor()

    # Store each embedding in the database
    for i, (doc_fragment, embedding) in enumerate(embeddings):
        cursor.execute(
            "INSERT INTO embeddings (id, doc_fragment, embeddings) VALUES (%s, %s, %s)",
            (i, doc_fragment, embedding[0]),
        )
    conn.commit()

    print(
        "import-data command executed. Data source: {}".format(
            args.data_source
        )
    )

You can see db.py below:

import os
import psycopg2


def get_connection():
    conn = psycopg2.connect(
        dbname=os.getenv("DB_NAME"),
        user=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        host=os.getenv("DB_HOST"),
        port=os.getenv("DB_PORT"),
    )

    return conn

You can see embedding.py below:

# import the required modules
import PyPDF2
import torch

def generate_embeddings(tokenizer, model, device, text):
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    return text, outputs.hidden_states[-1].mean(dim=1).tolist()


def read_pdf_file(pdf_path):
    pdf_document = PyPDF2.PdfReader(pdf_path)

    lines = []
    for page_number in range(len(pdf_document.pages)):
        page = pdf_document.pages[page_number]

        text = page.extract_text()

        lines.extend(text.splitlines())

    return lines
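
read_pdf_file splits each page into individual lines, which keeps the example simple but produces very small fragments. If you would rather work with larger, overlapping chunks, a minimal sketch could look like the function below (the 500-character window and 100-character overlap are arbitrary choices, not values used by the app):

def chunk_text(lines, chunk_size=500, overlap=100):
    # Join the extracted lines back into one string, then cut it into
    # overlapping character windows.
    text = " ".join(lines)
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size - overlap
    return chunks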

chat.py

After running the create-db and import-data commands, we need to run the chat command. When chat.py is executed, it invokes rag.py. We will talk about rag.py in more detail in the next section.

The chat process defines a chat function that facilitates an interactive chat with a user. It continuously prompts the user for questions, generates responses using a specified model, and displays the response to the user until they choose to exit the chat.

def chat(args, model, device, tokenizer):

The line above defines a function named chat that takes four arguments:

  • args: Additional arguments (if any) that may be passed to the function.
  • model: A PyTorch model used for generating responses to questions.
  • device: The device (CPU or GPU) on which the model is running.
  • tokenizer: An instance of a tokenizer used for tokenizing input questions.

answer = rag_query(tokenizer=tokenizer, model=model, device=device, query=question)

This line above calls the rag_query function with the provided tokenizer, model, device, and user question. It generates a response to the question.

Forming the answer is done by calling the rag_query function, which we will explore in detail while reviewing rag.py.

You can see chat.py below:

from rag import rag_query


def chat(args, model, device, tokenizer):
    print("Chat started. Type 'exit' to end the chat.")

    while True:
        question = input("Ask a question: ")

        if question.lower() == "exit":
            break

        answer = rag_query(tokenizer=tokenizer, model=model, device=device, query=question)

        print(f"You Asked: {question}")
        print(f"Answer: {answer}")

    print("Chat ended.")

rag.py

As the name suggests, this script contains the essential RAG logic necessary for the main functionality of this application. Let’s look at it closer.

template = """<s>[INST]
You are a friendly documentation search bot.
Use following piece of context to answer the question.
If the context is empty, try your best to answer without it.
Never mention the context.
Try to keep your answers concise unless asked to provide details.

Context: {context}
Question: {question}
[/INST]</s>
Answer:
"""

The template above provides a structured format for presenting a question, context (if available), and the bot's answer. It ensures consistency in response presentation and provides clear instructions (and context, if applicable) for the model on how to handle each query.

get_retrieval_condition function

Every RAG app relies on a retrieval mechanism, which is a core component of the RAG architecture. We retrieve data from our own database and use it to provide context for the LLM. That’s why we built a retrieval condition, used by the rag_query function, to fetch the relevant information (embeddings).

We use the get_retrieval_condition function to generate an SQL condition for retrieving relevant embeddings from the database.

def get_retrieval_condition(query_embedding, threshold=0.7):
    # Convert query embedding to a string format for SQL query
    query_embedding_str = ",".join(map(str, query_embedding))

    # SQL condition for cosine similarity
    condition = f"(embeddings <=> '{query_embedding_str}') < {threshold} ORDER BY embeddings <=> '{query_embedding_str}'"
    return condition

The function constructs an SQL condition to find and order embeddings based on their cosine distance (<=>) to a given query embedding. It takes a query embedding and a threshold value as input:

  • query_embedding: A list or array representing the embedding vector of the user's query.
  • threshold: A float value (default = 0.7) specifying the maximum cosine distance allowed for an embedding to be considered relevant. The lower the threshold, the closer the match must be.

query_embedding_str = ",".join(map(str, query_embedding))

This line above converts the query_embedding list into a string format suitable for SQL queries.

condition = f"(embeddings <=> '{query_embedding_str}') < {threshold} ORDER BY embeddings <=> '{query_embedding_str}'"

The line above constructs an SQL condition string that uses the <=> operator to calculate the cosine distance between the stored embeddings and the query_embedding. It ensures that only embeddings with a cosine distance below the specified threshold are considered, and it orders the results by that distance so the most relevant results are listed first.

The function returns the constructed SQL condition string.
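
To make this concrete, suppose the query embedding were a toy three-dimensional vector. Since generate_embeddings returns a nested list, the joined string already carries pgvector's bracketed vector literal format, and the final query executed in rag_query (covered in the next section) would look roughly like this:

SELECT doc_fragment FROM embeddings
WHERE (embeddings <=> '[0.1, 0.2, 0.3]') < 0.7
ORDER BY embeddings <=> '[0.1, 0.2, 0.3]'
LIMIT 5;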

rag_query function

The rag_query function integrates query embedding generation, document retrieval from a Postgres database, query and context combination, and response generation using an LLM to produce a relevant answer to the input query. It is arguably the most critical component of the RAG system.

Take a look at the whole function and then we’ll walk through it step-by-step.
 

def rag_query(tokenizer, model, device, query):
    # Generate query embedding
    query_embedding = generate_embeddings(
        tokenizer=tokenizer, model=model, device=device, text=query
    )[1]

    # Retrieve relevant embeddings from the database
    retrieval_condition = get_retrieval_condition(query_embedding)

    conn = get_connection()
    register_vector(conn)
    cursor = conn.cursor()
    cursor.execute(
        f"SELECT doc_fragment FROM embeddings WHERE {retrieval_condition} LIMIT 5"
    )
    retrieved = cursor.fetchall()

    rag_query = ' '.join([row[0] for row in retrieved])

    query_template = template.format(context=rag_query, question=query)

    input_ids = tokenizer.encode(query_template, return_tensors="pt")

    # Generate the response
    generated_response = model.generate(input_ids.to(device), max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(generated_response[0][input_ids.shape[-1]:], skip_special_tokens=True)

Let’s break down the rag_query function step by step to understand what each part does.

Generating Query Embeddings

query_embedding = generate_embeddings(
    tokenizer=tokenizer, model=model, device=device, text=query
)[1]

Here, we generate the query embedding for the input query. The generate_embeddings function returns a tuple containing the original text and its corresponding embedding vector. We extract the second element of the returned tuple, which is the query embedding.

Retrieving Relevant Embeddings from the Database

retrieval_condition = get_retrieval_condition(query_embedding)

Here, we create a condition to retrieve relevant documents from the database based on their cosine distance to the query embedding.

See the previous section to understand how the get_retrieval_condition function works.

conn = get_connection()
register_vector(conn)
cursor = conn.cursor()
cursor.execute(
    f"SELECT doc_fragment FROM embeddings WHERE {retrieval_condition} LIMIT 5"
)
retrieved = cursor.fetchall()

Here, we connect to the database and register the pgvector type with psycopg2 (register_vector). Then we execute an SQL query to select document fragments from the embeddings table based on the retrieval condition, limiting the results to the 5 most relevant fragments.

Preparing the Query Template

rag_query = ' '.join([row[0] for row in retrieved])

The code above concatenates the retrieved document fragments into a single space-separated string (rag_query) by joining the first element of each row returned from the database.

query_template = template.format(context=rag_query, question=query)

Here, we format the template string with the retrieved document fragments (context) and the original query text (the question).

Generating the Response

input_ids = tokenizer.encode(query_template, return_tensors="pt")

tokenizer.encode converts the text to input IDs, and return_tensors="pt" ensures the output is in PyTorch tensor format. We do this to tokenize the query template and convert it to tensor format suitable for model input.

generated_response = model.generate(input_ids.to(device), max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)

The code above generates a response from the tokenized input using the provided model. It limits the number of new tokens to 50 (the response length) and uses the end-of-sequence token ID for padding.

generated_response contains the generated text IDs.

return tokenizer.decode(generated_response[0][input_ids.shape[-1]:], skip_special_tokens=True)

Here, the code decodes the generated response tokens into a human-readable string, skipping the prompt tokens (everything up to input_ids.shape[-1]) and any special tokens. The output of the rag_query function is the final generated response text.

You can see rag.py below:

import torch
from pgvector.psycopg2 import register_vector

from db import get_connection
from embedding import generate_embeddings

template = """<s>[INST]
You are a friendly documentation search bot.
Use following piece of context to answer the question.
If the context is empty, try your best to answer without it.
Never mention the context.
Try to keep your answers concise unless asked to provide details.

Context: {context}
Question: {question}
[/INST]</s>
Answer:
"""

def get_retrieval_condition(query_embedding, threshold=0.7):
    # Convert query embedding to a string format for SQL query
    query_embedding_str = ",".join(map(str, query_embedding))

    # SQL condition for cosine similarity
    condition = f"(embeddings <=> '{query_embedding_str}') < {threshold} ORDER BY embeddings <=> '{query_embedding_str}'"
    return condition


def rag_query(tokenizer, model, device, query):
    # Generate query embedding
    query_embedding = generate_embeddings(
        tokenizer=tokenizer, model=model, device=device, text=query
    )[1]

    # Retrieve relevant embeddings from the database
    retrieval_condition = get_retrieval_condition(query_embedding)

    conn = get_connection()
    register_vector(conn)
    cursor = conn.cursor()
    cursor.execute(
        f"SELECT doc_fragment FROM embeddings WHERE {retrieval_condition} LIMIT 5"
    )
    retrieved = cursor.fetchall()

    rag_query = ' '.join([row[0] for row in retrieved])

    query_template = template.format(context=rag_query, question=query)

    input_ids = tokenizer.encode(query_template, return_tensors="pt")

    # Generate the response
    generated_response = model.generate(input_ids.to(device), max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(generated_response[0][input_ids.shape[-1]:], skip_special_tokens=True)

app.py

We covered most of app.py at the beginning of the post; see the create-db, import-data and chat commands.

    if hasattr(args, "func"):
        if torch.cuda.is_available():
            device = "cuda"
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        else:
            device = "cpu"
            bnb_config = None

        tokenizer = AutoTokenizer.from_pretrained(
            os.getenv("TOKENIZER_NAME"),
            token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
        )
        model = AutoModelForCausalLM.from_pretrained(
            os.getenv("MODEL_NAME"),
            token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
            quantization_config=bnb_config,
            device_map=device,
            torch_dtype=torch.float16,
        )

        args.func(args, model, device, tokenizer)
    else:
        print("Invalid command. Use '--help' for assistance.")

I think it is worth highlighting the code block above. Here we check whether a GPU is available. If CUDA is available, we set the device variable to cuda, indicating that GPU acceleration will be used; otherwise, we set the device to cpu for CPU execution.

If CUDA is available, we also initialize a BitsAndBytesConfig object with 4-bit quantization settings. We then initialize a tokenizer and a model for causal language modeling using the Hugging Face Transformers library.
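
None of the model details are hard-coded; everything comes from the .env file loaded by python-dotenv. Purely as an illustration (the model and tokenizer names below are placeholders for whatever model you choose, keeping in mind that its hidden size must match the vector(4096) column), a .env might look like this:

DB_USER=postgres
DB_PASSWORD=<your-password>
DB_HOST=localhost
DB_PORT=5432
DB_NAME=<your-database-name>
TOKENIZER_NAME=mistralai/Mistral-7B-Instruct-v0.2
MODEL_NAME=mistralai/Mistral-7B-Instruct-v0.2
HUGGING_FACE_ACCESS_TOKEN=<your-hugging-face-token>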

You can see app.py below:

import argparse
from enum import Enum
from dotenv import load_dotenv
import os
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from commands.chat import chat
from commands.create_db import create_db
from commands.import_data import import_data

load_dotenv()


class Command(Enum):
    CREATE_DB = "create-db"
    IMPORT_DATA = "import-data"
    CHAT = "chat"


def main():
    parser = argparse.ArgumentParser(description="Application Description")

    subparsers = parser.add_subparsers(
        title="Subcommands",
        dest="command",
        help="Display available subcommands",
    )

    # create-db command
    subparsers.add_parser(
        Command.CREATE_DB.value, help="Create a database"
    ).set_defaults(func=create_db)

    # import-data command
    import_data_parser = subparsers.add_parser(
        Command.IMPORT_DATA.value, help="Import data"
    )
    import_data_parser.add_argument(
        "data_source", type=str, help="Specify the PDF data source"
    )
    import_data_parser.set_defaults(func=import_data)

    # chat command
    chat_parser = subparsers.add_parser(
        Command.CHAT.value, help="Use chat feature"
    )
    chat_parser.set_defaults(func=chat)

    args = parser.parse_args()

    if hasattr(args, "func"):
        if torch.cuda.is_available():
            device = "cuda"
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        else:
            device = "cpu"
            bnb_config = None

        tokenizer = AutoTokenizer.from_pretrained(
            os.getenv("TOKENIZER_NAME"),
            token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
        )
        model = AutoModelForCausalLM.from_pretrained(
            os.getenv("MODEL_NAME"),
            token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
            quantization_config=bnb_config,
            device_map=device,
            torch_dtype=torch.float16,
        )

        args.func(args, model, device, tokenizer)
    else:
        print("Invalid command. Use '--help' for assistance.")


if __name__ == "__main__":
    main()

Future improvements

Looking ahead, there are some ideas to enhance the pgvector-rag-app. First, developing a user interface would be beneficial. I have started experimenting with Streamlit, which seems promising for creating a quick demo interface. Since this app is mainly for demonstration purposes, you have the flexibility to code it using your preferred frontend stack.

For the demo, I manually set up the AWS instance. However, if you plan to use this app regularly, automating the setup process for your instances is highly recommended as part of your infrastructure automation strategy.

The app has only been tested with a single PDF document so far, so it needs more work on handling multiple PDFs (see the sketch below for one possible starting point). It is also possible that there are better models for this task, and it is always a good idea to try out different models and evaluate them.
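
As a hedged starting point for multi-PDF support, assuming the repository layout stays as described (read_pdf_file living in embedding.py), the import step could simply loop over every PDF in a directory:

import glob
import os

from embedding import read_pdf_file


def read_pdf_directory(directory):
    # Collect text lines from every PDF in the directory, keeping track
    # of which file each fragment came from.
    fragments = []
    for pdf_path in sorted(glob.glob(os.path.join(directory, "*.pdf"))):
        for line in read_pdf_file(pdf_path):
            fragments.append((os.path.basename(pdf_path), line))
    return fragments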

Right now, the app doesn't have a way to identify which user is making a query. By tailoring queries based on users' roles and privileges, we can personalize their experience and meet stricter security requirements.
 

Summary

Constructing a RAG application solely with Postgres and pgvector is entirely feasible. However, it is also a prime example of how pgvector is just one part of the equation: while Postgres and pgvector together let us use Postgres as a vector database, a complete AI application requires more.

Discover our EDB Postgres AI, an integrated platform designed for modern analytical and AI workloads. Enhanced by our new Postgres AI extension, our platform speeds up the development and deployment of your AI applications. With robust security and comprehensive support, you can build enterprise-grade AI apps quickly and effectively.
