Skip to content

Commit

Permalink
- Fix for import presets // local and file
Browse files Browse the repository at this point in the history
  • Loading branch information
regiellis committed Sep 19, 2024
1 parent 724b3b8 commit ee5d630
Showing 1 changed file with 72 additions and 43 deletions.
115 changes: 72 additions & 43 deletions invokeai_presets_cli/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def import_presets() -> None:
console.print("Import cancelled.")
return

presets_to_import = []

if source == "Local File":
file_path = inquirer.text(message="Enter the path to the JSON file")
try:
Expand Down Expand Up @@ -107,15 +109,48 @@ def import_presets() -> None:
)
return

# Ask user if they want to select presets or import all
import_choice = inquirer.list_input(
"How would you like to proceed?",
choices=["Select Presets", "Import All", "Cancel"],
)

if import_choice == "Cancel":
console.print("Import cancelled.")
return

if import_choice == "Select Presets":
choices = [
inquirer.Checkbox(
"selected_presets",
message="Select presets to import",
choices=[
preset.get("name", f"Unnamed Preset {i}")
for i, preset in enumerate(presets_to_import)
],
)
]
answers = inquirer.prompt(choices)
if not answers or not answers["selected_presets"]:
console.print("No presets selected. Import cancelled.")
return
selected_presets = [
preset
for preset in presets_to_import
if preset.get("name") in answers["selected_presets"]
]
else: # Import All
selected_presets = presets_to_import

db = get_db(connection=True)
existing_presets = {
preset["name"]: preset
for preset in get_presets_list(show_defaults=True, show_all=True)
preset[1]: preset # Tuple // name is 1st element
for preset in get_presets_list(show_defaults=False, show_all=True)
}
presets_to_update = []
presets_to_create = []

for preset in presets_to_import:
for preset in selected_presets:
if not validate_preset(preset):
console.print(
f"[yellow]Skipping invalid preset: {preset.get('name', 'Unknown')}[/yellow]"
Expand Down Expand Up @@ -160,9 +195,9 @@ def import_presets() -> None:

# Perform database operations
try:
with db.conn:
with db:
# This automatically manages transactions
cursor = db.conn.cursor()
cursor = db.cursor()
# Disable triggers temporarily
cursor.execute("PRAGMA recursive_triggers = OFF;")

Expand Down Expand Up @@ -220,7 +255,7 @@ def convert_preset_format(preset: Dict[str, Any]) -> Dict[str, Any]:
"name": preset["name"],
"type": preset.get("type", "user"), # Default to 'user' if not specified
"preset_data": {
"prompt": preset.get("prompt", ""),
"positive_prompt": preset.get("positive_prompt", preset.get("prompt", "")),
"negative_prompt": preset.get("negative_prompt", ""),
},
}
Expand All @@ -231,28 +266,30 @@ def validate_preset(preset: Dict[str, Any]) -> bool:
return False

if "preset_data" in preset:
# Validate and update the old structure
# Validate the old structure
if not isinstance(preset["preset_data"], dict):
return False
preset["preset_data"]["prompt"] = preset["preset_data"].get("prompt", "")
preset["preset_data"]["negative_prompt"] = preset["preset_data"].get(
"negative_prompt", ""
)
elif "prompt" in preset:
# Validate and update the new structure
preset["negative_prompt"] = preset.get("negative_prompt", "")
# Convert to old structure
preset["preset_data"] = {
"prompt": preset["prompt"],
"negative_prompt": preset["negative_prompt"],
}
del preset["prompt"]
del preset["negative_prompt"]
# Ensure positive_prompt and negative_prompt exist, but don't add if not present
if (
"prompt" in preset["preset_data"]
and "positive_prompt" not in preset["preset_data"]
):
preset["preset_data"]["positive_prompt"] = preset["preset_data"].pop(
"prompt"
)
preset["preset_data"].setdefault("positive_prompt", "")
preset["preset_data"].setdefault("negative_prompt", "")
elif "positive_prompt" in preset or "prompt" in preset:
# Validate the new structure
if "prompt" in preset and "positive_prompt" not in preset:
preset["positive_prompt"] = preset.pop("prompt")
preset.setdefault("positive_prompt", "")
preset.setdefault("negative_prompt", "")
else:
return False

# Ensure 'type' is present
preset["type"] = preset.get("type", "user")
preset.setdefault("type", "user")

return True

Expand Down Expand Up @@ -441,24 +478,14 @@ def create_snapshot() -> None:
snapshot_path = os.path.join(SNAPSHOTS_DIR, snapshot_name)

try:
with Progress() as progress:
task = progress.add_task("[green]Creating snapshot...", total=100)

# Use SQLite backup API with progress feedback
with (
get_db(connection=True) as source_conn,
sqlite3.connect(snapshot_path) as dest_conn,
):
source = source_conn.cursor()
source.execute("SELECT count(*) FROM sqlite_master")
total_objects = source.fetchone()[0]
console.print("[green]Creating snapshot...[/green]")

def progress_callback(status, remaining, total):
progress.update(
task, completed=int((total - remaining) / total * 100)
)

source_conn.backup(dest_conn, pages=1, progress=progress_callback)
# Use SQLite backup API
with (
get_db(connection=True) as source_conn,
sqlite3.connect(snapshot_path) as dest_conn,
):
source_conn.backup(dest_conn)

snapshots = load_snapshots()
snapshots.append({"name": snapshot_name, "timestamp": timestamp})
Expand All @@ -468,14 +495,16 @@ def progress_callback(status, remaining, total):
old_snapshot_path = os.path.join(SNAPSHOTS_DIR, oldest_snapshot["name"])
if os.path.exists(old_snapshot_path):
os.remove(old_snapshot_path)
console.print(f"Removed oldest snapshot: {oldest_snapshot['name']}")
feedback_message(
f"Removed oldest snapshot: {oldest_snapshot['name']}", "info"
)

save_snapshots(snapshots)
console.print(f"[green]Snapshot created successfully: {snapshot_name}[/green]")
feedback_message(f"Created snapshot: {snapshot_name}", "success")
except sqlite3.Error as e:
console.print(f"[bold red]SQLite Error creating snapshot:[/bold red] {str(e)}")
feedback_message(f"Error creating snapshot: {str(e)}", "error")
except Exception as e:
console.print(f"[bold red]Error creating snapshot:[/bold red] {str(e)}")
feedback_message(f"Error creating snapshot: {str(e)}", "error")


def load_snapshots() -> List[Dict[str, str]]:
Expand Down

0 comments on commit ee5d630

Please sign in to comment.