Skip to content

Commit

Permalink
- Blackened the code
Browse files Browse the repository at this point in the history
  • Loading branch information
regiellis committed Sep 20, 2024
1 parent a7fc023 commit ff2a88f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 28 deletions.
2 changes: 1 addition & 1 deletion invokeai_presets_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def styles_import_command(
help="The type of preset to import, either 'user' or 'project'. Default is 'user'",
show_default="False",
),
]=False
] = False
):
import_presets(project_type)

Expand Down
84 changes: 65 additions & 19 deletions invokeai_presets_cli/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ def get_db(connection: bool) -> Any:


# ANCHOR: PRESET FUNCTIONS START
def get_presets_list(show_defaults: bool, show_all: bool, show_project: bool, page: int = 1, items_per_page: int = 10) -> Tuple[int, List[Dict[str, Any]]]:
def get_presets_list(
show_defaults: bool,
show_all: bool,
show_project: bool,
page: int = 1,
items_per_page: int = 10,
) -> Tuple[int, List[Dict[str, Any]]]:
db = get_db(connection=True)
base_query = "SELECT * FROM style_presets"
conditions = {
Expand All @@ -66,8 +72,10 @@ def get_presets_list(show_defaults: bool, show_all: bool, show_project: bool, pa
(True, False, False): "WHERE type = 'default'",
(False, False, True): "WHERE type = 'project'",
}
condition = conditions.get((show_defaults, show_all, show_project), "WHERE type = 'default'")

condition = conditions.get(
(show_defaults, show_all, show_project), "WHERE type = 'default'"
)

# Count total presets
count_query = f"SELECT COUNT(*) FROM style_presets {condition}".strip()
total_presets = db.execute(count_query).fetchone()[0]
Expand All @@ -79,8 +87,13 @@ def get_presets_list(show_defaults: bool, show_all: bool, show_project: bool, pa

return total_presets, presets

def get_preset_page_count(show_defaults: bool, show_all: bool, show_project: bool, items_per_page: int = 10) -> int:
total_presets, _ = get_presets_list(show_defaults, show_all, show_project, 1, items_per_page)

def get_preset_page_count(
show_defaults: bool, show_all: bool, show_project: bool, items_per_page: int = 10
) -> int:
total_presets, _ = get_presets_list(
show_defaults, show_all, show_project, 1, items_per_page
)
return math.ceil(total_presets / items_per_page)


Expand Down Expand Up @@ -164,7 +177,9 @@ def import_presets(project_type: bool) -> None:
db = get_db(connection=True)
existing_presets = {
preset[1]: preset # Tuple // name is 1st element
for preset in get_presets_list(show_defaults=False, show_all=True, show_project=False)
for preset in get_presets_list(
show_defaults=False, show_all=True, show_project=False
)
}
presets_to_update = []
presets_to_create = []
Expand Down Expand Up @@ -272,7 +287,9 @@ def convert_preset_format(preset: Dict[str, Any], project_type) -> Dict[str, Any
# Convert from the new format to the database format
return {
"name": preset["name"],
"type": "project" if project_type else preset.get("type", "user"), # Default to 'user' if not specified
"type": (
"project" if project_type else preset.get("type", "user")
), # Default to 'user' if not specified
"preset_data": {
"positive_prompt": preset.get("positive_prompt", preset.get("prompt", "")),
"negative_prompt": preset.get("negative_prompt", ""),
Expand Down Expand Up @@ -313,8 +330,16 @@ def validate_preset(preset: Dict[str, Any]) -> bool:
return True


def display_presets(show_defaults: bool, show_all: bool, show_project: bool, page: int = 1, items_per_page: int = 10) -> None:
total_presets, presets = get_presets_list(show_defaults, show_all, show_project, page, items_per_page)
def display_presets(
show_defaults: bool,
show_all: bool,
show_project: bool,
page: int = 1,
items_per_page: int = 10,
) -> None:
total_presets, presets = get_presets_list(
show_defaults, show_all, show_project, page, items_per_page
)
total_pages = math.ceil(total_presets / items_per_page)

presets_table = create_table(
Expand All @@ -323,7 +348,17 @@ def display_presets(show_defaults: bool, show_all: bool, show_project: bool, pag
)

if not presets:
types = ' or '.join(t for t, f in zip(['default', 'all', 'project'], [show_defaults, show_all, show_project]) if f) or 'user'
types = (
" or ".join(
t
for t, f in zip(
["default", "all", "project"],
[show_defaults, show_all, show_project],
)
if f
)
or "user"
)
feedback_message(f"No presets found for {types}", "warning")
return

Expand All @@ -342,16 +377,23 @@ def display_presets(show_defaults: bool, show_all: bool, show_project: bool, pag

if total_pages > 1:
while True:
choice = typer.prompt("Enter 'n' for next page, 'p' for previous page, or 'q' to quit", default="q")
if choice.lower() == 'n' and page < total_pages:
choice = typer.prompt(
"Enter 'n' for next page, 'p' for previous page, or 'q' to quit",
default="q",
)
if choice.lower() == "n" and page < total_pages:
page += 1
display_presets(show_defaults, show_all, show_project, page, items_per_page)
display_presets(
show_defaults, show_all, show_project, page, items_per_page
)
break
elif choice.lower() == 'p' and page > 1:
elif choice.lower() == "p" and page > 1:
page -= 1
display_presets(show_defaults, show_all, show_project, page, items_per_page)
display_presets(
show_defaults, show_all, show_project, page, items_per_page
)
break
elif choice.lower() == 'q':
elif choice.lower() == "q":
break
else:
console.print("Invalid choice. Please try again.")
Expand Down Expand Up @@ -431,7 +473,9 @@ def delete_presets() -> None:
presets_to_delete = []

if delete_source == "Select from list":
all_presets = get_presets_list(show_defaults=False, show_all=True, show_project=False)
all_presets = get_presets_list(
show_defaults=False, show_all=True, show_project=False
)
# TODO: Refactor ...
choices = [f"{preset[1]} (ID: {preset[0]})" for preset in all_presets]
questions = [
Expand Down Expand Up @@ -477,8 +521,10 @@ def delete_presets() -> None:
)
return

#TODO: This is sticky...need to refactor
user_presets = get_presets_list(show_defaults=False, show_all=False, show_project=False)
# TODO: This is sticky...need to refactor
user_presets = get_presets_list(
show_defaults=False, show_all=False, show_project=False
)

presets_to_delete = [
# TODO: This brittle ASF...need to stop using tuples for this or start unpacking the, but for now, it works
Expand Down
28 changes: 20 additions & 8 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,41 +98,53 @@ def test_list_presets(runner, mock_db):
simplified_output = simplify_rich_output(result.stdout)
print(f"List presets output: {simplified_output}") # Debug print
assert result.exit_code == 0
assert any(keyword in simplified_output for keyword in ["ID", "Name", "Prompts", "No presets found"])
assert any(
keyword in simplified_output
for keyword in ["ID", "Name", "Prompts", "No presets found"]
)

if "No presets found" in simplified_output:
assert "Warning" in simplified_output
return # Skip the rest of the tests if no presets are found

# Only continue with these tests if presets were found

# Test with pagination options
result = runner.invoke(invoke_presets_cli, ["list", "--page", "2", "--items-per-page", "5"])
result = runner.invoke(
invoke_presets_cli, ["list", "--page", "2", "--items-per-page", "5"]
)
simplified_output = simplify_rich_output(result.stdout)
print(f"List presets with pagination output: {simplified_output}") # Debug print
assert result.exit_code == 0
assert "Page" in simplified_output

# Test with other options
result = runner.invoke(invoke_presets_cli, ["list", "--only-defaults", "--page", "1", "--items-per-page", "10"])
result = runner.invoke(
invoke_presets_cli,
["list", "--only-defaults", "--page", "1", "--items-per-page", "10"],
)
simplified_output = simplify_rich_output(result.stdout)
print(f"List default presets output: {simplified_output}") # Debug print
assert result.exit_code == 0
assert "Page" in simplified_output

# Test navigation
with patch('builtins.input', side_effect=['n', 'p', 'q']):
with patch("builtins.input", side_effect=["n", "p", "q"]):
result = runner.invoke(invoke_presets_cli, ["list"])
simplified_output = simplify_rich_output(result.stdout)
print(f"List presets with navigation output: {simplified_output}") # Debug print
print(
f"List presets with navigation output: {simplified_output}"
) # Debug print
assert result.exit_code == 0
assert "Page" in simplified_output

# Test invalid navigation input
with patch('builtins.input', side_effect=['x', 'q']):
with patch("builtins.input", side_effect=["x", "q"]):
result = runner.invoke(invoke_presets_cli, ["list"])
simplified_output = simplify_rich_output(result.stdout)
print(f"List presets with invalid navigation output: {simplified_output}") # Debug print
print(
f"List presets with invalid navigation output: {simplified_output}"
) # Debug print
assert result.exit_code == 0
assert "Invalid choice" in simplified_output or "Page" in simplified_output

Expand Down

0 comments on commit ff2a88f

Please sign in to comment.