LLM integration

This quickstart will walk you through setting up HeimdaLLM with OpenAI's LLM. The end result is a function that takes natural language input and returns a trusted SQL SELECT query, constrained to your requirements.

If you only want to use HeimdaLLM to validate an existing SQL query, see this quickstart.

Tip

You can also find this quickstart code in a Jupyter Notebook here.

First, let's set up our imports.

import logging
from typing import Sequence

import structlog

from heimdallm.bifrosts.sql.sqlite.select.bifrost import Bifrost
from heimdallm.bifrosts.sql.sqlite.select.envelope import PromptEnvelope
from heimdallm.bifrosts.sql.sqlite.select.validator import ConstraintValidator
from heimdallm.bifrosts.sql.common import FqColumn, JoinCondition, ParameterizedConstraint
from heimdallm.llm_providers import openai

logging.basicConfig(level=logging.ERROR)
structlog.configure(logger_factory=structlog.stdlib.LoggerFactory())

Now let's set up our LLM integration. We'll use OpenAI. Remember to store your OpenAI API token securely.

# load our openai api key secret from the environment.
# create a `.env` file with your openai api secret.
import os
from dotenv import load_dotenv

load_dotenv()
openai_api_key = os.getenv("OPENAI_API_SECRET")
assert openai_api_key

llm = openai.Client(api_key=openai_api_key)

Now we'll define our database schema. You can dump this directly from your database, but a better method is to dump it out beforehand to a file, manually trim out tables and columns that you don't want the LLM to know about, and load it from that file. You can also add SQL comments to help explain the schema to the LLM.

# abbreviated example schema
db_schema = """
CREATE TABLE customer (
    customer_id INT NOT NULL,
    store_id INT NOT NULL,
    first_name VARCHAR(45) NOT NULL,
    last_name VARCHAR(45) NOT NULL,
    email VARCHAR(50) DEFAULT NULL,
    address_id INT NOT NULL,
    active CHAR(1) DEFAULT 'Y' NOT NULL,
    create_date TIMESTAMP NOT NULL,
    last_update TIMESTAMP NOT NULL,
    PRIMARY KEY (customer_id)
);
"""
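As a rough illustration of the "dump and trim" step, the sketch below reads the `CREATE TABLE` statements out of a SQLite database and drops tables you do not want the LLM to see. It uses only Python's standard `sqlite3` module. The `dump_schema` helper, the `hidden_tables` parameter, and the example tables are assumptions made for this sketch and are not part of HeimdaLLM. Trimming individual columns and adding explanatory SQL comments would still be a manual step on the resulting text.

```python
import sqlite3


def dump_schema(conn: sqlite3.Connection, hidden_tables: set[str]) -> str:
    """Return the CREATE TABLE statements for every table not in hidden_tables."""
    rows = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    statements = [
        sql
        for name, sql in rows
        # skip SQLite's internal tables and anything explicitly hidden
        if sql and not name.startswith("sqlite_") and name not in hidden_tables
    ]
    return "".join(f"{sql};\n\n" for sql in statements)


# Hypothetical database with one table we want to keep away from the LLM.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE customer (
        customer_id INT NOT NULL,
        email VARCHAR(50) DEFAULT NULL,
        PRIMARY KEY (customer_id)
    );
    CREATE TABLE staff_password (
        staff_id INT NOT NULL,
        password_hash TEXT NOT NULL
    );
    """
)

trimmed_schema = dump_schema(conn, hidden_tables={"staff_password"})
print(trimmed_schema)
```

You could write `trimmed_schema` to a file once, review it by hand, and load that file as `db_schema` instead of querying the database on every run.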

Let's define our constraint validator(s). These are used to constrain the SQL query so that it only has access to tables and columns that you allow. For more information on the methods that you can override in the derived class, look here.

class CustomerConstraintValidator(ConstraintValidator):
    def requester_identities(self) -> Sequence[ParameterizedConstraint]:
        return [
            ParameterizedConstraint(
                column="customer.customer_id",
                placeholder="customer_id",
            ),
        ]

    def parameterized_constraints(self) -> Sequence[ParameterizedConstraint]:
        return []

    def select_column_allowed(self, column: FqColumn) -> bool:
        return True

    def allowed_joins(self) -> Sequence[JoinCondition]:
        return []

    def max_limit(self) -> int | None:
        return None


validator = CustomerConstraintValidator()
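Returning `True` from `select_column_allowed` permits every column in the schema. A stricter policy might consult an explicit allowlist instead. The sketch below shows only that allowlist logic as a standalone function. `ALLOWED_COLUMNS` and `is_column_allowed` are hypothetical names, and the function takes plain strings because this quickstart does not show how table and column names are read from an `FqColumn`; check the API reference before wiring something like this into your validator.

```python
# Hypothetical allowlist: table name -> columns a generated query may select.
ALLOWED_COLUMNS = {
    "customer": {"first_name", "last_name", "email"},
}


def is_column_allowed(table: str, column: str) -> bool:
    """Return True only for columns explicitly listed in ALLOWED_COLUMNS."""
    return column in ALLOWED_COLUMNS.get(table, set())


print(is_column_allowed("customer", "email"))       # True
print(is_column_allowed("customer", "address_id"))  # False
print(is_column_allowed("payment", "amount"))       # False
```

Denying by default means a column added to the schema later stays hidden until you list it.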

We'll define our prompt envelope. This adds additional context to any human input so that the LLM is guided to produce a correct response.

envelope = PromptEnvelope(
    llm=llm,
    db_schema=db_schema,
    validators=[validator],
)

Now we can bring everything together into a 🌈 Bifrost:

bifrost = Bifrost(
    prompt_envelope=envelope,
    llm=llm,
    constraint_validators=[validator],
)

You can now traverse untrusted human input with the Bifrost.

query = bifrost.traverse("Show me my email")
print(query)

The output should be something like:

SELECT customer.email
FROM customer
WHERE customer.customer_id=:customer_id
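Note that the query contains the `:customer_id` placeholder rather than a literal value, so your application supplies the requester's identity when it executes the query. The sketch below runs the example output against a throwaway in-memory SQLite database using the standard `sqlite3` module. The table contents and the `authenticated_customer_id` variable are assumptions for illustration; in a real application the value would come from your own authentication layer, not from the user's prompt.

```python
import sqlite3

# Throwaway in-memory database with two hypothetical customers.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customer (customer_id INT PRIMARY KEY, email VARCHAR(50))")
conn.executemany(
    "INSERT INTO customer (customer_id, email) VALUES (?, ?)",
    [(1, "alice@example.com"), (2, "bob@example.com")],
)

# The example output from above, pasted in as a string.
trusted_query = """
SELECT customer.email
FROM customer
WHERE customer.customer_id=:customer_id
"""

# Assumed to come from your session or auth layer, never from the prompt.
authenticated_customer_id = 1

# sqlite3 binds named placeholders from a dict keyed by placeholder name.
rows = conn.execute(
    trusted_query, {"customer_id": authenticated_customer_id}
).fetchall()
print(rows)  # [('alice@example.com',)]
```

Because the identity is bound as a parameter, the same validated query returns only the row belonging to whichever customer is authenticated.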