Skip to content

Commit

Permalink
GPT token tracking and conversation history
Browse files Browse the repository at this point in the history
Signed-off-by: bghira <bghira@users.github.com>
  • Loading branch information
bghira committed Apr 15, 2023
1 parent d8a48c0 commit dc01517
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 36 deletions.
1 change: 1 addition & 0 deletions discord_tron_master/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self):
AppConfig.set_flask(self.app)
database_handler = DatabaseHandler(self.app, self.config)
self.db = database_handler.db
from discord_tron_master.models.conversation import Conversations
from discord_tron_master.models.transformers import Transformers
self.migrate = Migrate(self.app, self.db)
self.register_routes()
Expand Down
148 changes: 148 additions & 0 deletions discord_tron_master/classes/openai/chat_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# A class for managing ChatML histories via Flask DB.
from discord_tron_master.classes.app_config import AppConfig
from discord_tron_master.models.conversation import Conversations
from discord_tron_master.classes.openai.tokens import TokenTester

import json, logging

# Shared application config and the (already-initialized) Flask app handle.
config = AppConfig()
app = AppConfig.flask

# ChatML needs a Flask app context for DB access; fail fast if the app isn't up yet.
if app is None:
    raise Exception("Flask app is not initialized.")

class ChatML:
    """Manage a user's ChatML (OpenAI chat-format) conversation history via the Flask DB.

    Wraps the Conversations model for a single owner: pending replies are
    validated against ``token_limit`` and the oldest history entries are pruned
    until the new reply fits.
    """

    def __init__(self, conversation: Conversations, token_limit: int = 2048):
        self.conversations = conversation
        self.user_id = conversation.owner
        self.history = conversation.get_history(self.user_id) or Conversations.get_new_history()
        self.user_config = config.get_user_config(self.user_id)
        # Pick up their current role from their profile.
        self.role = self.user_config["gpt_role"]
        self.reply = {}  # The pending {"role", "content"} entry being validated/appended.
        self.tokenizer = TokenTester()
        self.token_limit = token_limit

    # Pick up a DB connector and store it. Create the conversation, if needed.
    async def initialize_conversation(self):
        with app.app_context():
            self.conversations = Conversations()
            self.conversation = await self.get_conversation_or_create()

    async def get_conversation_or_create(self):
        """Fetch this user's conversation row, creating it from the current role/history when absent."""
        with app.app_context():
            conversation = self.conversations.get_by_owner(self.user_id)
            logging.debug(f"Picked up conversation from db: {conversation}")
            if conversation is None:
                conversation = self.conversations.create(self.user_id, self.role, self.history)
            return conversation

    async def validate_reply(self):
        """Return True when the pending reply can fit within the token limit.

        BUGFIX: the original returned None (falsy) on the path where the reply
        already fit, which made add_to_history() reject every short reply.
        """
        logging.debug(f"Validating reply")
        if await self.is_reply_too_long():
            # Too long -- prune the oldest entries until the reply fits (or nothing is left).
            logging.debug(f"Eureka! We can enter Alzheimers mode.")
            await self.remove_history_until_reply_fits()
            # Even an empty history may not be enough if the reply alone exceeds the limit.
            return not await self.is_reply_too_long()
        return True

    # See if we can fit everything without emptying the history too far.
    async def can_new_reply_fit_without_emptying_everything_from_history(self):
        if await self.is_history_empty() and not await self.is_reply_too_long():
            logging.debug(f"History is empty and reply is short enough to fit.")
            return True
        # If we are not empty, we can fit the reply if the history is short enough.
        if await self.get_history_token_count() < self.token_limit:
            logging.debug(f"History is short enough to fit.")
            return True
        # BUGFIX: the original fell through to ``return True``, making this check a no-op.
        logging.debug(f"Reply cannot fit without emptying the history.")
        return False

    # Loop over the history and remove items until the new reply will fit with the current text.
    async def remove_history_until_reply_fits(self):
        logging.debug(f"Stripping conversation back until the reply fits.")
        while await self.is_reply_too_long() and len(await self.get_history()) > 0:
            logging.debug(f"Reply is too long. Removing oldest history item.")
            await self.remove_oldest_history_item()
        logging.debug(f"Cleanup is complete. Returning newly pruned history.")
        return await self.get_history()

    # Remove the oldest history item and return the new history.
    async def remove_oldest_history_item(self):
        conversation = await self.get_conversation_or_create()
        item = conversation.history.pop(0)
        logging.debug(f"Removing oldest history item: {item}")
        with app.app_context():
            Conversations.set_history(self.user_id, conversation.history)
        return Conversations.get_history(owner=self.user_id)

    # Look at the actual token counts of each item and compare against our limit.
    async def is_reply_too_long(self):
        reply_token_count = await self.get_reply_token_count()
        history_token_count = await self.get_history_token_count()
        logging.debug(f"Reply token count: {reply_token_count}")
        logging.debug(f"History token count: {history_token_count}")
        return reply_token_count + history_token_count > self.token_limit

    async def get_reply_token_count(self):
        """Token count of the pending reply, serialized as JSON."""
        return self.tokenizer.get_token_count(json.dumps(self.reply))

    async def get_history_token_count(self):
        # Pad the value by 512 to accommodate the metadata in the JSON we can't
        # really count right here. (The old comment claimed 64; the code adds 512.)
        return self.tokenizer.get_token_count(json.dumps(await self.get_history())) + 512

    # Format the history as a string for OpenAI.
    async def get_prompt(self):
        return json.dumps(await self.get_history())

    async def get_history(self):
        """Return the (decoded) history list from the DB-backed conversation."""
        conversation = await self.get_conversation_or_create()
        logging.debug(f"Conversation: {conversation}")
        return conversation.history

    async def is_history_empty(self):
        """True when the stored history is empty or still the pristine default."""
        history = await self.get_history()
        if len(history) == 0:
            return True
        if history == Conversations.get_new_history():
            return True
        return False

    async def add_user_reply(self, content: str):
        return await self.add_to_history("user", content)

    async def add_system_reply(self, content: str):
        return await self.add_to_history("system", content)

    async def add_assistant_reply(self, content: str):
        return await self.add_to_history("assistant", content)

    async def add_to_history(self, role: str, content: str):
        """Validate and append a {role, content} entry, returning the new history.

        Raises:
            ValueError: when the reply can never fit within ``token_limit``.
        """
        # Store the reply for processing
        self.reply = {"role": role, "content": content}
        if not await self.validate_reply():
            raise ValueError(f"I am sorry. It seems your reply would overrun the limits of reality and time. We are currently stuck at {self.token_limit} tokens, and your message used {await self.get_reply_token_count()} tokens. Please try again.")
        with app.app_context():
            conversation = await self.get_conversation_or_create()
            conversation.history.append(self.reply)
            self.conversations.set_history(self.user_id, conversation.history)
            return conversation.history

    def truncate_conversation_history(self, conversation_history, new_prompt, max_tokens=2048):
        """Drop the oldest history entries (in place) until history + prompt fit in max_tokens.

        NOTE(review): duplicates the module-level helper in tokens.py; consider consolidating.
        """
        # Calculate tokens for new_prompt
        new_prompt_token_count = self.tokenizer.get_token_count(new_prompt)
        if new_prompt_token_count >= max_tokens:
            raise ValueError("The new prompt alone exceeds the maximum token limit.")

        # Calculate tokens for conversation_history
        conversation_history_token_counts = [len(self.tokenizer.tokenize(entry)) for entry in conversation_history]
        total_tokens = sum(conversation_history_token_counts) + new_prompt_token_count

        # Truncate conversation history if total tokens exceed max_tokens
        while total_tokens > max_tokens:
            conversation_history.pop(0)  # Remove the oldest entry
            conversation_history_token_counts.pop(0)  # Remove the oldest entry's token count
            total_tokens = sum(conversation_history_token_counts) + new_prompt_token_count

        return conversation_history
37 changes: 5 additions & 32 deletions discord_tron_master/classes/openai/tokens.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,13 @@
from tiktoken import Tokenizer
from tiktoken.tokenizer import Tokenizer as OpenAITokenizer
import tiktoken

# Quick smoke test: count the tokens in a sample string.
# BUGFIX: tiktoken exposes no ``Tokenizer`` class; use get_encoding() instead.
text = "This is a sample text to check the token count."

tokenizer = tiktoken.get_encoding("cl100k_base")
tokens = tokenizer.encode(text)

token_count = len(tokens)
print(f"Token count: {token_count}")

class TokenTester:
    """Thin wrapper around tiktoken for tokenizing and counting tokens for a model."""

    def __init__(self, engine: str = "gpt-3.5-turbo"):
        # Resolve the BPE encoding appropriate for the requested model name.
        self.tokenizer = tiktoken.encoding_for_model(engine)

    def tokenize(self, text):
        """Return the list of token ids for ``text``.

        BUGFIX: removed a stale ``return tokenizer.tokenize(text)`` line left
        over from the old (nonexistent) tiktoken API; it referenced an undefined
        global and made the real encode() call unreachable.
        """
        return self.tokenizer.encode(text, allowed_special='all')

    def get_token_count(self, text):
        """Number of tokens ``text`` encodes to."""
        return len(self.tokenize(text))

def truncate_conversation_history(conversation_history, new_prompt, max_tokens=2048, tokenizer=None):
    """Drop the oldest conversation entries (in place) until history + prompt fit.

    Args:
        conversation_history: list of strings; mutated in place.
        new_prompt: the prompt about to be sent alongside the history.
        max_tokens: token budget shared by history and prompt.
        tokenizer: any object with a ``tokenize(text)`` method returning a
            sequence of tokens. Defaults to TokenTester. (BUGFIX: the original
            instantiated ``OpenAITokenizer``, which does not exist in tiktoken.)

    Returns:
        The (possibly truncated) conversation_history list.

    Raises:
        ValueError: when the new prompt alone meets or exceeds ``max_tokens``.
    """
    if tokenizer is None:
        tokenizer = TokenTester()

    # Calculate tokens for new_prompt
    new_prompt_token_count = len(tokenizer.tokenize(new_prompt))
    if new_prompt_token_count >= max_tokens:
        raise ValueError("The new prompt alone exceeds the maximum token limit.")

    # Calculate tokens for conversation_history
    conversation_history_token_counts = [len(tokenizer.tokenize(entry)) for entry in conversation_history]
    total_tokens = sum(conversation_history_token_counts) + new_prompt_token_count

    # Truncate conversation history if total tokens exceed max_tokens
    while total_tokens > max_tokens:
        conversation_history.pop(0)  # Remove the oldest entry
        conversation_history_token_counts.pop(0)  # Remove the oldest entry's token count
        total_tokens = sum(conversation_history_token_counts) + new_prompt_token_count

    return conversation_history
20 changes: 17 additions & 3 deletions discord_tron_master/cogs/image/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ async def on_message(self, message):
else:
# We were mentioned, but no attachments. They must want to converse.
logging.debug("Message contains no attachments. Initiating conversation.")
gpt = GPT()
response = await gpt.discord_bot_response(prompt=message.content, ctx=message)
await discord.send_large_message(message, message.author.mention + ' ' + response)
try:
gpt = GPT()
from discord_tron_master.classes.openai.chat_ml import ChatML
from discord_tron_master.models.conversation import Conversations
app = AppConfig.flask
with app.app_context():
user_conversation = Conversations.create(message.author.id, self.config.get_user_setting(message.author.id, "gpt_role"), Conversations.get_new_history())
chat_ml = ChatML(user_conversation)
await chat_ml.add_user_reply(message.content)
response = await gpt.discord_bot_response(prompt=await chat_ml.get_prompt(), ctx=message)
await chat_ml.add_assistant_reply(response)
await discord.send_large_message(message, message.author.mention + ' ' + response)
except Exception as e:
await message.channel.send(
f"{message.author.mention} I am sorry, friend. I had an error while generating text inference: {e}"
)
logging.error(f"Error generating text inference: {e}\n\nStack trace:\n{await clean_traceback(traceback.format_exc())}")
22 changes: 21 additions & 1 deletion discord_tron_master/cogs/user/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
from discord.ext import commands
from discord_tron_master.models.transformers import Transformers
from discord_tron_master.models.conversation import Conversations
from discord_tron_master.classes.text_replies import return_random as random_fact
from discord_tron_master.classes.app_config import AppConfig
import logging

config = AppConfig()
app = AppConfig.flask

class User(commands.Cog):
    """User-facing commands for managing personal GPT state."""

    def __init__(self, bot):
        self.bot = bot

    @commands.command(name="clear", help="Clear your GPT conversation history and start again.")
    async def clear_history(self, ctx):
        """Wipe the invoking user's stored conversation history."""
        owner_id = ctx.author.id
        mention = ctx.author.mention
        try:
            # All DB access must run inside the Flask application context.
            with app.app_context():
                Conversations.clear_history_by_owner(owner=owner_id)
            await ctx.send(
                f"{mention} Well, well, well. It is like I don't even know you anymore. Did you know {random_fact()}?"
            )
        except Exception as error:
            logging.error(f"Caught error when clearing user conversation history: {error}")
            await ctx.send(
                f"{mention} The smoothbrain geriatric that writes my codebase did not correctly implement that method. I am sorry. Trying again will only lead to tears."
            )
98 changes: 98 additions & 0 deletions discord_tron_master/models/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from .base import db
import datetime, json

class Conversations(db.Model):
    """A single user's GPT conversation: an owner, a system role, and a JSON-encoded message history."""
    __tablename__ = 'conversations'
    id = db.Column(db.Integer, primary_key=True)
    owner = db.Column(db.BigInteger(), unique=False, nullable=False)
    # BUGFIX: role was unique=True, which would prevent two users from ever
    # sharing the same role string (e.g. the default system prompt).
    role = db.Column(db.String(255), unique=False, nullable=False)
    # History is a JSON-encoded list of {"role", "content"} dicts, so the empty
    # default is '[]' (BUGFIX: the old '{}' default decoded to a dict, not a list).
    history = db.Column(db.Text(), nullable=False, default='[]')
    created = db.Column(db.DateTime, nullable=False, default=db.func.now())
    updated = db.Column(db.DateTime, nullable=False, default=db.func.now())

    @staticmethod
    def get_all():
        """Return every conversation row."""
        return Conversations.query.all()

    @staticmethod
    def delete_all():
        """Delete every conversation row, committing after each delete."""
        all = Conversations.get_all()
        for conversation in all:
            db.session.delete(conversation)
            db.session.commit()

    @staticmethod
    def clear_history_by_owner(owner: int):
        """Reset a user's history to the empty default. No-op when the user has no conversation."""
        import logging
        conversation = Conversations.get_by_owner(owner)
        if conversation is None:
            # BUGFIX: previously raised AttributeError when the user had no conversation yet.
            logging.debug(f"No conversation found for owner {owner}; nothing to clear.")
            return
        logging.debug(f"Conversation before clearing: {conversation.history}")
        conversation.history = json.dumps(Conversations.get_new_history())
        logging.debug(f"Cleared conversation. New history: {conversation.history}")
        # Update Flask DB timestamp
        conversation.updated = db.func.now()
        db.session.commit()

    @staticmethod
    def create(owner: int, role: str, history: dict = None):
        """Create a conversation for ``owner``, or return the existing one unchanged."""
        existing_definition = Conversations.query.filter_by(owner=owner).first()
        if existing_definition is not None:
            return existing_definition
        if history is None:
            raise ValueError("History must be provided when creating a new conversation")
        conversation = Conversations(owner=owner, role=role, history=json.dumps(history))
        db.session.add(conversation)
        db.session.commit()
        return conversation

    @staticmethod
    def get_by_owner(owner: int):
        """Return the conversation for ``owner`` with history decoded, or None.

        BUGFIX: guards against a missing row (was AttributeError) and only
        decodes when the attribute is still the raw JSON string, so repeated
        calls in one session don't raise TypeError.
        """
        conversation = Conversations.query.filter_by(owner=owner).first()
        if conversation is not None and isinstance(conversation.history, str):
            # NOTE(review): assigning the decoded list back onto the ORM object
            # marks the row dirty; presumably harmless here — verify.
            conversation.history = json.loads(conversation.history)
        return conversation

    @staticmethod
    def set_history(owner: int, history: dict):
        """Persist ``history`` (JSON-encoded) for ``owner`` and bump the timestamp."""
        conversation = Conversations.get_by_owner(owner)
        conversation.history = json.dumps(history)
        conversation.updated = db.func.now()
        db.session.commit()
        return conversation

    @staticmethod
    def get_new_history(role: str = None) -> list:
        """Return a fresh history list, seeded with a system message when a role is given."""
        if role is None:
            return []
        # CONSISTENCY FIX: use the "content" key, matching the OpenAI chat format
        # and the {"role", "content"} entries appended by ChatML.add_to_history.
        return [{"role": "system", "content": role}]

    @staticmethod
    def get_history(owner: int):
        """Return the decoded history list for ``owner``."""
        conversation = Conversations.get_by_owner(owner)
        # Unload it if it is a string:
        if isinstance(conversation.history, str):
            conversation.history = json.loads(conversation.history)
        return conversation.history

    @staticmethod
    def set_role(owner: int, role: str):
        """Update the stored role for ``owner`` and bump the timestamp."""
        conversation = Conversations.get_by_owner(owner)
        conversation.role = role
        conversation.updated = db.func.now()
        db.session.commit()
        return conversation

    @staticmethod
    def get_role(owner: int):
        """Return the stored role for ``owner``. NOTE(review): raises AttributeError when absent."""
        conversation = Conversations.get_by_owner(owner)
        return conversation.role

    def to_dict(self):
        """Plain-dict view of the row (history left in whatever form it currently holds)."""
        return {
            'owner': self.owner,
            'role': self.role,
            'history': self.history,
            'created': self.created,
            'updated': self.updated
        }

    def to_json(self):
        """JSON string of to_dict(). Uses the module-level json import."""
        return json.dumps(self.to_dict())

0 comments on commit dc01517

Please sign in to comment.