Extracting tables from SQL using SQLFluff

I got an assignment the other day to produce documentation to send to a customer. The extraction of the table names required to execute a certain Databricks notebook was part of the task. The plan was to build an object dependency tree.

The query spanned 279 lines. How can you extract only the table names from a file without having to manually look for them? Can we make use of this technique again in the future?

SQLFluff to the rescue

SQLFluff is a SQL linter and formatter. You can use it to format/beautify your SQL code, catch bugs in code before executing it, and parse queries. It's also a good way to establish a uniform style guide across your team for how queries should look.

It’s such a great tool; I can’t recommend it enough. Have a quick look at their website to see how to get started.

Parsing the SQL Notebook

I won’t take any credit here; most of the code was adapted from this post by Ben Chuanlong Du, who works at Google.

I just had to fix some minor details and extract the parts relevant to the task at hand.

Prerequisites

You’ll need Python 3 installed, as well as sqlfluff:

pip3 install sqlfluff

After this step you can open a command line (PowerShell/Bash) and execute:

sqlfluff parse databricks_notebook.sql > query.log

Depending on your sqlfluff version, you may also need to specify a dialect, e.g. --dialect databricks.

Remember that you can export Databricks notebooks from the GUI. Check the Databricks documentation on how to do it.

The output file will contain the parse tree of the query with all the interesting data.
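For reference, the relevant part of the parse tree looks roughly like this (a hypothetical excerpt; the exact node names, indentation, and column layout vary with the sqlfluff version and dialect):

```
[L: 12, P:  8]    |                from_expression_element:
[L: 12, P:  8]    |                    table_expression:
[L: 12, P:  8]    |                        table_reference:
[L: 12, P:  8]    |                            naked_identifier:  'sales'
[L: 12, P: 13]    |                            dot:               '.'
[L: 12, P: 14]    |                            naked_identifier:  'orders'
```

The script below keys off the `table_reference:` lines and reads the identifiers on the lines that follow.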

The gist of it

The idea is quite simple: read the file, use a regular expression to find every line where the identifier is a table_reference, collect the referenced names, and print out the resulting set.

import re
from pathlib import Path


def extract_identifier(line):
    # The identifier is the last whitespace-separated token on the
    # line, wrapped in quotes -- strip them off.
    identifier = line.strip().rsplit(maxsplit=1)[-1]
    return identifier[1:-1]


def extract_reference(lines, idx):
    # A qualified reference (schema.table) has a "dot:" line between
    # the two identifier lines; otherwise there is only a table name.
    if "dot:" in lines[idx + 2]:
        return extract_identifier(lines[idx + 1]), extract_identifier(lines[idx + 3])
    return "", extract_identifier(lines[idx + 1])


def main():
    with Path("./query.log").open() as fin:
        lines = fin.readlines()

    # Collect a (schema, table) pair for every table_reference node.
    tab_refs = [
        extract_reference(lines, idx)
        for idx, line in enumerate(lines)
        if re.search(r"\|\s*table_reference:", line)
    ]

    tables = {table for _, table in tab_refs if table}

    print(tables)

if __name__ == "__main__":
    main()

The example output will look something like this:

{
    'Table_0',
    'Table_1',
    'Table_2',
    'Table_3',
    'Table_4',
    ....
    'Table_N'
}
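To see how the helper functions behave on individual parse-tree lines, here is a small self-contained check. The sample lines are hypothetical but follow the shape of sqlfluff's parse output shown above:

```python
def extract_identifier(line):
    # The identifier is the last whitespace-separated token on the
    # line, wrapped in quotes -- strip them off.
    identifier = line.strip().rsplit(maxsplit=1)[-1]
    return identifier[1:-1]


def extract_reference(lines, idx):
    # A qualified reference (schema.table) has a "dot:" line between
    # the two identifier lines; otherwise there is only a table name.
    if "dot:" in lines[idx + 2]:
        return extract_identifier(lines[idx + 1]), extract_identifier(lines[idx + 3])
    return "", extract_identifier(lines[idx + 1])


# A hypothetical four-line excerpt of the parse log, starting at
# a table_reference node for a qualified name "sales.orders".
lines = [
    "|        table_reference:",
    "|            naked_identifier:    'sales'",
    "|            dot:                 '.'",
    "|            naked_identifier:    'orders'",
]

print(extract_reference(lines, 0))  # ('sales', 'orders')
```

The schema part of the tuple is what `main()` discards with `for _, table in tab_refs`, keeping only the table names.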

Conclusion

So here’s a quick and dirty way of extracting the useful parts of queries. It can be quite helpful for producing documentation or for determining dependencies between your queries/notebooks. A huge timesaver. Many thanks to Ben Chuanlong Du for sharing his knowledge with the world.

Have fun!
