Skip to content

Commit

Permalink
Add resolution mgmt to the master
Browse files Browse the repository at this point in the history
Signed-off-by: bghira <bghira@users.github.com>
  • Loading branch information
bghira committed Apr 10, 2023
1 parent fc567d3 commit 6cd0492
Show file tree
Hide file tree
Showing 24 changed files with 740 additions and 160 deletions.
24 changes: 12 additions & 12 deletions discord_tron_master/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
logging.basicConfig(level=logging.INFO)
from discord_tron_master.classes import log_format

from discord_tron_master.classes.database_handler import DatabaseHandler
from discord_tron_master.api import API
Expand All @@ -12,6 +12,7 @@
from discord_tron_master.classes.command_processor import CommandProcessor

config = AppConfig()

api = API()
from discord_tron_master.auth import Auth
auth = Auth()
Expand Down Expand Up @@ -66,7 +67,6 @@ def run_discord_bot():
with ThreadPoolExecutor(max_workers=3) as executor:
tasks = [
executor.submit(run_flask_api),
# executor.submit(run_websocket_hub),
executor.submit(run_discord_bot),
]

Expand Down Expand Up @@ -122,35 +122,35 @@ def create_client_tokens(username: str):
# Do we have a user at all?
existing_user = User.query.filter_by(username=username).first()
if existing_user is None:
print(f"User {username} does not exist")
logging.info(f"User {username} does not exist")
return
# Does it have a client?
client = existing_user.has_client()
if not client:
print(f"Client does not exist for user {existing_user.username} - we will try to create one.")
logging.info(f"Client does not exist for user {existing_user.username} - we will try to create one.")
client = existing_user.create_client()
else:
print(f"User already had an OAuth Client registered. Using that: {client}")
logging.info(f"User already had an OAuth Client registered. Using that: {client}")
# Did we deploy them an API Key?
print("Checking for API Key...")
logging.info("Checking for API Key...")
api_key = ApiKey.query.filter_by(client_id=client.client_id, user_id=client.user_id).first()
if api_key is None:
print("No API Key found, generating one...")
logging.info("No API Key found, generating one...")
api_key = ApiKey.generate_by_user_id(existing_user.id)
print(f"API key for client/user:\n" + json.dumps(api_key.to_dict(), indent=4))
logging.info(f"API key for client/user:\n" + json.dumps(api_key.to_dict(), indent=4))
# Do we have tokens for this user?
print("Checking for existing tokens...")
logging.info("Checking for existing tokens...")
existing_tokens = OAuthToken.query.filter_by(user_id=existing_user.id).first()
if existing_tokens is not None:
print(f"Tokens already exist for user {username}:\n" + json.dumps(existing_tokens.to_dict(), indent=4))
logging.info(f"Tokens already exist for user {username}:\n" + json.dumps(existing_tokens.to_dict(), indent=4))
return existing_tokens
# It seems like we can proceed.
print(f"Creating tokens for user {username}")
logging.info(f"Creating tokens for user {username}")
host = config.get_websocket_hub_host()
port = config.get_websocket_hub_port()
tls = config.get_websocket_hub_tls()
with api.app.app_context():
refresh_token = auth.create_refresh_token(client.client_id, user_id=existing_user.id)
refresh_token = auth.create_refresh_token(client.client_id, existing_user.id)
protocol = "ws"
if tls:
protocol = "wss"
Expand Down
21 changes: 14 additions & 7 deletions discord_tron_master/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

class API:
def __init__(self):
print("Loaded Flask API")
config = AppConfig()
logging.debug("Loaded Flask API")
self.config = AppConfig()
self.app = Flask(__name__)
database_handler = DatabaseHandler(self.app, config)
AppConfig.set_flask(self.app)
database_handler = DatabaseHandler(self.app, self.config)
self.db = database_handler.db
from discord_tron_master.models.transformers import Transformers
self.migrate = Migrate(self.app, self.db)
self.register_routes()
self.auth = None
Expand All @@ -22,7 +24,12 @@ def add_resource(self, resource, route):
self.api.add_resource(resource, route)

def run(self, host='0.0.0.0', port=5000):
self.app.run(host=host, port=port)
import ssl
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(self.config.project_root + '/config/server_cert.pem', self.config.project_root + '/config/server_key.pem')
# Set the correct SSL/TLS version (You can change PROTOCOL_TLS to the appropriate version if needed)
ssl_context.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
self.app.run(host=host, port=port, ssl_context=ssl_context)

def set_auth(self, auth):
self.auth = auth
Expand All @@ -31,23 +38,23 @@ def register_routes(self):
# assuming you have 'app' defined as your Flask instance
@self.app.route("/refresh_token", methods=["POST"])
def refresh_token():
print("refresh_token endpoint hit")
logging.debug("refresh_token endpoint hit")
refresh_token = request.json.get("refresh_token")
if not refresh_token:
return jsonify({"error": "refresh_token is required"}), 400
from discord_tron_master.models import OAuthToken
token_data = OAuthToken.query.filter_by(refresh_token=refresh_token).first()
if not token_data:
return jsonify({"error": "Invalid refresh token"}), 400
print(f"Refreshed access token requested from {token_data.client_id}")
logging.debug(f"Refreshed access token requested from {token_data.client_id}")
# Logic to refresh the access token using the provided refresh_token
new_ticket = self.auth.refresh_access_token(token_data)
response = new_ticket.to_dict()

return jsonify(response)
@self.app.route("/authorize", methods=["POST"])
def authorize():
print("authorize endpoint hit")
logging.debug("authorize endpoint hit")
client_id = request.json.get("client_id")
api_key = request.json.get("api_key")

Expand Down
27 changes: 12 additions & 15 deletions discord_tron_master/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,9 @@ def validate_api_key(self, api_key):
logging.info("Key is perpetually active.")
return True
elif key_data and key_data.expires > datetime.datetime.utcnow():
logging.info("The api key has not expired yet.")
logging.debug("The api key has not expired yet.")
return True
logging.error("API Key was Invalid: %s" % api_key)
print(f"API Key was Invalid: {key_data}")
return False

def validate_access_token(self, access_token):
Expand All @@ -69,19 +68,18 @@ def validate_access_token(self, access_token):
# We found a perpetual token. This is probably bad.
raise Exception("Token is perpetually active.")
elif token_data and (datetime.datetime.timestamp(token_data.issued_at)*1000 + token_data.expires_in) > datetime.datetime.utcnow().timestamp():
logging.info("The token has not expired yet.")
logging.debug("The token has not expired yet.")
return True
logging.error("Access token was Invalid: %s" % access_token)
print(f"Access token was Invalid: {token_data}")
return False

# As far as I can tell, this is the most important aspect.
def create_refresh_token(self, token_data, scopes=None, expires_in=None):
def create_refresh_token(self, client_id, user_id, scopes=None, expires_in=None):
import secrets
# Don't do this method if you have a token already.
token = OAuthToken.query.filter_by(client_id=token_data.client_id, user_id=token_data.user_id).first()
token = OAuthToken.query.filter_by(client_id=client_id, user_id=user_id).first()
if not token:
token = OAuthToken(token_data.client_id, token_data.user_id, scopes=scopes, expires_in=expires_in)
token = OAuthToken(client_id, user_id, scopes=scopes, expires_in=expires_in)
# Currently, we're only generating refresh_token once, at deploy.
# This is less secure, but simpler for now.
token.refresh_token = OAuthToken.make_token()
Expand All @@ -92,24 +90,23 @@ def create_refresh_token(self, token_data, scopes=None, expires_in=None):

# Refresh the auth link using the refresh_token
def refresh_authorization(self, token_data, expires_in=None):
print("Refreshing access token!")
logging.debug("Refreshing access token!")
token_data.access_token = OAuthToken.make_token()
print("Updating token for client_id: %s, user_id: %s, previous issued_at was %s" % (token_data.client_id, token_data.user_id, token_data.issued_at))
logging.debug("Updating token for client_id: %s, user_id: %s, previous issued_at was %s" % (token_data.client_id, token_data.user_id, token_data.issued_at))
token_data.set_issue_timestamp()
print("After update, access_token is now %s" % token_data.access_token)
print("After setting timestamp, issued_at is now %s" % token_data.issued_at)
logging.debug("After update, access_token is now %s" % token_data.access_token)
logging.debug("After setting timestamp, issued_at is now %s" % token_data.issued_at)
db.session.add(token_data)
db.session.commit()
return token_data

# An existing access_token can be updated.
def refresh_access_token(self, token_data, expires_in=None):
print("Refreshing access token!")
logging.debug("Refreshing access token!")
token_data.access_token = OAuthToken.make_token()
print("Updating token for client_id: %s, user_id: %s, previous issued_at was %s" % (token_data.client_id, token_data.user_id, token_data.issued_at))
logging.debug("Updating token for client_id: %s, user_id: %s, previous issued_at was %s" % (token_data.client_id, token_data.user_id, token_data.issued_at))
token_data.set_issue_timestamp()
print("After update, access_token is now %s" % token_data.access_token)
print("After setting timestamp, issued_at is now %s" % token_data.issued_at)
logging.debug("After setting timestamp, issued_at is now %s" % token_data.issued_at)
db.session.add(token_data)
db.session.commit()
return token_data
15 changes: 9 additions & 6 deletions discord_tron_master/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from discord_tron_master.websocket_hub import WebSocketHub
from discord_tron_master.classes.queue_manager import QueueManager
from discord_tron_master.classes.worker_manager import WorkerManager
from discord_tron_master.classes.custom_help import CustomHelp
from discord_tron_master.classes.app_config import AppConfig
config = AppConfig()

class DiscordBot:
discord_instance = None
Expand All @@ -18,7 +21,7 @@ def __init__(self, token):
intents.members = True
intents.message_content = True
intents.presences = True
self.bot = commands.Bot(command_prefix="!", intents=intents)
self.bot = commands.Bot(command_prefix=config.get_command_prefix(), intents=intents, help_command=CustomHelp())
DiscordBot.discord_instance = self

@classmethod
Expand Down Expand Up @@ -46,9 +49,9 @@ async def run(self):

async def load_cogs(self, cogs_path="discord_tron_master/cogs"):
import logging
logging.info("Loading cogs! Path: " + cogs_path)
logging.debug("Loading cogs! Path: " + cogs_path)
for root, _, files in os.walk(cogs_path):
logging.info("Found cogs: " + str(files))
logging.debug("Found cogs: " + str(files))
for file in files:
if file.endswith(".py"):
cog_path = os.path.join(root, file).replace("/", ".").replace("\\", ".")[:-3]
Expand All @@ -57,10 +60,10 @@ async def load_cogs(self, cogs_path="discord_tron_master/cogs"):
cog_module = importlib.import_module(cog_path)
cog_class_name = getattr(cog_module, file[:-3].capitalize())
await self.bot.add_cog(cog_class_name(self.bot))
logging.info(f"Loaded cog: {cog_path}")
logging.debug(f"Loaded cog: {cog_path}")
except Exception as e:
logging.info(f"Failed to load cog: {cog_path}")
logging.info(e)
logging.error(f"Failed to load cog: {cog_path}")
logging.error(e)

async def find_channel(self, channel_id):
for guild in self.bot.guilds:
Expand Down
73 changes: 52 additions & 21 deletions discord_tron_master/classes/app_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import os
import json, logging, os
from pathlib import Path

DEFAULT_CONFIG = {
Expand Down Expand Up @@ -29,9 +28,12 @@
}

DEFAULT_USER_CONFIG = {
"steps": 100,
"temperature": 0.9,
"strength": 0.5,
"model": "theintuitiveye/HARDblend",
"variation_model": "HARDblend",
"negative_prompt": "(child, teen) (malformed, malignant)",
"steps": 100,
"positive_prompt": "(beautiful, unreal engine 5, highly detailed, hyperrealistic)",
"resolution": {
"width": 512,
Expand All @@ -40,23 +42,18 @@
}

class AppConfig:
flask = None
def __init__(self):
parent = os.path.dirname(Path(__file__).resolve().parent)
self.project_root = parent
config_path = os.path.join(parent, "config")
self.config_path = os.path.join(config_path, "config.json")
self.example_config_path = os.path.join(config_path, "example.json")
self.reload_config()

if not os.path.exists(self.config_path):
with open(self.example_config_path, "r") as example_file:
example_config = json.load(example_file)

with open(self.config_path, "w") as config_file:
json.dump(example_config, config_file)

with open(self.config_path, "r") as config_file:
self.config = json.load(config_file)

self.config = self.merge_dicts(DEFAULT_CONFIG, self.config)
@classmethod
def set_flask(cls, flask):
cls.flask = flask

@staticmethod
def merge_dicts(dict1, dict2):
Expand All @@ -68,7 +65,24 @@ def merge_dicts(dict1, dict2):
result[key] = value
return result

def reload_config(self):
if not os.path.exists(self.config_path):
with open(self.example_config_path, "r") as example_file:
example_config = json.load(example_file)
with open(self.config_path, "w") as config_file:
json.dump(example_config, config_file, indent=4)
with open(self.config_path, "r") as config_file:
self.config = json.load(config_file)
self.config = self.merge_dicts(DEFAULT_CONFIG, self.config)

def get_log_level(self):
self.reload_config()
level = self.config.get("log_level", "INFO")
result = getattr(logging, level.upper(), "ERROR")
return result

def get_user_config(self, user_id):
self.reload_config()
user_config = self.config.get("users", {}).get(str(user_id), {})
return self.merge_dicts(DEFAULT_USER_CONFIG, user_config)

Expand All @@ -83,47 +97,64 @@ def merge_dicts(dict1, dict2):
return result

def get_concurrent_slots(self):
self.reload_config()
return self.config.get("concurrent_slots", 1)

def get_command_prefix(self):
return self.config.get("cmd_prefix", "+")
self.reload_config()
return self.config.get("cmd_prefix")

def get_websocket_hub_host(self):
self.reload_config()
return self.config.get("websocket_hub", {}).get("host", "localhost")

def get_websocket_hub_port(self):
self.reload_config()
return self.config.get("websocket_hub", {}).get("port", 6789)

def get_websocket_hub_tls(self):
self.reload_config()
return self.config.get("websocket_hub", {}).get("tls", False)

def get_huggingface_api_key(self):
self.reload_config()
return self.config["huggingface_api"].get("api_key", None)

def get_discord_api_key(self):
self.reload_config()
return self.config.get("discord", {}).get("api_key", None)

def get_local_model_path(self):
return self.config["huggingface"].get("local_model_path", None)
self.reload_config()
return self.config.get("huggingface", {}).get("local_model_path", "/root/.cache/huggingface/hub")

def set_user_config(self, user_id, user_config):
self.config.get("users", {})[str(user_id)] = user_config
with open(self.config_path, "w") as config_file:
json.dump(self.config, config_file)
logging.info(f"Saving config: {self.config}")
json.dump(self.config, config_file, indent=4)

def set_user_setting(self, user_id, setting_key, value):
user_id = str(user_id)
self.config.get("users", {}).get(user_id, {})[setting_key] = value
with open(self.config_path, "w") as config_file:
json.dump(self.config, config_file)
user_config = self.get_user_config(user_id)
user_config[setting_key] = value
self.set_user_config(user_id, user_config)

def get_user_setting(self, user_id, setting_key, default_value=None):
self.reload_config()
user_id = str(user_id)
return self.config.get("users", {}).get(user_id, {}).get(setting_key, default_value)
user_config = self.get_user_config(user_id)
return user_config.get(setting_key, default_value)

def get_mysql_user(self):
self.reload_config()
return self.config.get("mysql", {}).get("user", "diffusion")
def get_mysql_password(self):
self.reload_config()
return self.config.get("mysql", {}).get("password", "diffusion_pwd")
def get_mysql_hostname(self):
self.reload_config()
return self.config.get("mysql", {}).get("hostname", "localhost")
def get_mysql_dbname(self):
self.reload_config()
return self.config.get("mysql", {}).get("dbname", "diffusion_master")
Loading

0 comments on commit 6cd0492

Please sign in to comment.