Skip to content

Commit

Permalink
Merge pull request #6 from wharton/feature/reserved-words
Browse files Browse the repository at this point in the history
Simplify treatment of reserved words to follow Python convention.
  • Loading branch information
FlipperPA authored Apr 3, 2024
2 parents 6597e7f + 9de8d66 commit f9c0319
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 26 deletions.
49 changes: 25 additions & 24 deletions automagic_rest/management/commands/build_data_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from collections import namedtuple

from glob import glob
import keyword
import os
from re import sub

from django.core.management.base import BaseCommand
from django.db import connections
from django.template.loader import render_to_string

from automagic_rest.settings import get_reserved_words_to_append_underscore

# Map PostgreSQL column types to Django ORM field type
# Please note: "blank=True, null=True" must be typed
# exactly, as it will be stripped out for primary keys
Expand Down Expand Up @@ -37,15 +38,8 @@
"jsonb": "JSONField({}blank=True, null=True{})",
}

# Created a reserved words list that can not be used for Django field
# names. Start with the Python reserved words list, and add any additional
# fields reserved by DRF or Automagic REST.
# We will then append `_var` to any fields with these names, and map to
# the underlying database column in the models.
RESERVED_WORDS = keyword.kwlist

# Additional reserved words for Django REST Framework
RESERVED_WORDS.append("format")
# Words that can't be used as column names
RESERVED_WORDS = get_reserved_words_to_append_underscore()


def fetch_result_with_blank_row(cursor):
Expand All @@ -54,7 +48,9 @@ def fetch_result_with_blank_row(cursor):
model and column are written in the loop.
"""
results = cursor.fetchall()
results.append(("__BLANK__", "__BLANK__", "__BLANK__", "integer", "__BLANK__", 0, 0))
results.append(
("__BLANK__", "__BLANK__", "__BLANK__", "integer", "__BLANK__", 0, 0)
)
desc = cursor.description
nt_result = namedtuple("Result", [col[0] for col in desc])

Expand All @@ -81,7 +77,10 @@ def add_arguments(self, parser):
action="store",
dest="owner",
default="my_pg_user",
help='Select schemata from this PostgreSQL owner user. Defaults to the "wrdsadmn" owner.',
help=(
'Select schemata from this PostgreSQL owner user. Defaults to the '
'"wrdsadmn" owner.'
),
)
parser.add_argument(
"--path",
Expand All @@ -95,7 +94,10 @@ def add_arguments(self, parser):
action="store_true",
dest="verbose",
default=False,
help="""Sets verbose mode; displays each model built, instead of just schemata.""",
help=(
"Sets verbose mode; displays each model built, instead of just "
"schemata."
),
)

def get_db(self, options):
Expand Down Expand Up @@ -239,7 +241,7 @@ def write_schema_files(self, root_path, context):
with open(
f"""{root_path}/models/{context["schema_name"]}.py""", "w"
) as f:
output = render_to_string(f"automagic_rest/models.html", context)
output = render_to_string("automagic_rest/models.html", context)
f.write(output)

def handle(self, *args, **options):
Expand Down Expand Up @@ -301,12 +303,8 @@ def handle(self, *args, **options):

# If the column name is a Python reserved word, append an underscore
# to follow the Python convention
if row.column_name in RESERVED_WORDS or row.column_name.endswith("_"):
if row.column_name.endswith("_"):
under_score = ""
else:
under_score = "_"
column_name = "{}{}var".format(row.column_name, under_score)
if row.column_name in RESERVED_WORDS:
column_name = f"{row.column_name}_"
db_column = ", db_column='{}'".format(row.column_name)
else:
column_name = row.column_name
Expand All @@ -322,14 +320,17 @@ def handle(self, *args, **options):
if decimal_places is None:
decimal_places = decimal_places_default

db_column += f", max_digits={max_digits}, decimal_places={decimal_places}"
db_column += (
f", max_digits={max_digits}, decimal_places={decimal_places}"
)

if row.data_type in COLUMN_FIELD_MAP:
if primary_key_has_been_set:
field_map = COLUMN_FIELD_MAP[row.data_type].format("", db_column)
else:
# We'll make the first column the primary key, since once is required in the Django ORM
# and this is read-only. Primary keys can not be set to NULL in Django.
# We'll make the first column the primary key, since once is
# required in the Django ORM and this is read-only. Primary keys can
# not be set to NULL in Django.
field_map = (
COLUMN_FIELD_MAP[row.data_type]
.format("primary_key=True", db_column)
Expand All @@ -349,5 +350,5 @@ def handle(self, *args, **options):
# Pop off the final false row, and write the URLs file.
context["routes"].pop()
with open(f"{root_path}/urls.py", "w") as f:
output = render_to_string(f"automagic_rest/urls.html", context)
output = render_to_string("automagic_rest/urls.html", context)
f.write(output)
16 changes: 16 additions & 0 deletions automagic_rest/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import keyword


def get_reserved_words_to_append_underscore():
"""
A list of reserved words list that can not be used for Django field names. This
includes the Python reserved words list, and additional fields not allowed by
Django REST Framework.
We will append `_var` to the model field names and map to the underlying database
column in the models in the code generator.
"""
reserved_words = keyword.kwlist
reserved_words.append("format")

return reserved_words
6 changes: 4 additions & 2 deletions automagic_rest/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)

from .pagination import estimate_count, CountEstimatePagination
from .settings import get_reserved_words_to_append_underscore


def split_basename(basename):
Expand Down Expand Up @@ -234,10 +235,10 @@ def get_positions(self):
Return a dict of keyed column names and their ordinal positions as values.
"""

RESERVED_WORDS = get_reserved_words_to_append_underscore()
positions = {}

cursor = connections[self.db_name].cursor()

cursor.execute(
"""
SELECT column_name, ordinal_position
Expand All @@ -250,6 +251,7 @@ def get_positions(self):

for row in cursor.fetchall():
# 0 = column_name, 1 = ordinal_position
positions[row[0]] = row[1]
column_name = f"{row[0]}_" if row[0] in RESERVED_WORDS else row[0]
positions[column_name] = row[1]

return positions

0 comments on commit f9c0319

Please sign in to comment.