-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GPT token tracking and conversation history
Signed-off-by: bghira <bghira@users.github.com>
- Loading branch information
bghira
committed
Apr 15, 2023
1 parent
d8a48c0
commit dc01517
Showing
6 changed files
with
290 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# A class for managing ChatML histories via Flask DB. | ||
from discord_tron_master.classes.app_config import AppConfig | ||
from discord_tron_master.models.conversation import Conversations | ||
from discord_tron_master.classes.openai.tokens import TokenTester | ||
|
||
import json, logging | ||
|
||
# Shared application config singleton for this module.
config = AppConfig()
# The Flask app is exposed as a class attribute on AppConfig.
app = AppConfig.flask

# Fail fast at import time: every DB operation below requires an app context.
if app is None:
    raise Exception("Flask app is not initialized.")
|
||
class ChatML:
    """Manage a per-user ChatML conversation history persisted via the Flask DB.

    A pending reply is validated against a token budget before being appended;
    when the budget would be exceeded, the oldest history entries are pruned
    until the new reply fits (or the history is exhausted).
    """

    def __init__(self, conversation: Conversations, token_limit: int = 2048):
        self.conversations = conversation
        self.user_id = conversation.owner
        # Fall back to a fresh history when none is stored for this user.
        self.history = conversation.get_history(self.user_id) or Conversations.get_new_history()
        self.user_config = config.get_user_config(self.user_id)
        # Pick up their current role from their profile.
        self.role = self.user_config["gpt_role"]
        self.reply = {}
        self.tokenizer = TokenTester()
        self.token_limit = token_limit

    # Pick up a DB connector and store it. Create the conversation, if needed.
    async def initialize_conversation(self):
        with app.app_context():
            self.conversations = Conversations()
            self.conversation = await self.get_conversation_or_create()

    async def get_conversation_or_create(self):
        """Fetch this user's conversation row, creating one if absent."""
        with app.app_context():
            conversation = self.conversations.get_by_owner(self.user_id)
            logging.debug(f"Picked up conversation from db: {conversation}")
            if conversation is None:
                conversation = self.conversations.create(self.user_id, self.role, self.history)
            return conversation

    async def validate_reply(self):
        """Return True when the pending reply fits within the token budget.

        Prunes the oldest history entries first when needed.
        Fix: this previously returned True unconditionally, which made the
        over-length ValueError in add_to_history unreachable; it now reports
        whether the reply actually fits after pruning.
        """
        logging.debug(f"Validating reply")
        if await self.is_reply_too_long():
            # let's clean up until it does fit.
            logging.debug(f"Eureka! We can enter Alzheimers mode.")
            await self.remove_history_until_reply_fits()
            return not await self.is_reply_too_long()
        return True

    # See if we can fit everything without emptying the history too far.
    async def can_new_reply_fit_without_emptying_everything_from_history(self):
        if await self.is_history_empty() and not await self.is_reply_too_long():
            logging.debug(f"History is empty and reply is short enough to fit.")
            return True
        # If we are not empty, we can fit the reply if the history is short enough.
        if await self.get_history_token_count() < self.token_limit:
            logging.debug(f"History is short enough to fit.")
            return True
        # Fix: the fallthrough used to return True (with an author comment
        # doubting it), making this check vacuous. A history at or over the
        # token limit cannot accommodate the new reply.
        logging.debug(f"History is at or over the token limit; reply cannot fit.")
        return False

    # Loop over the history and remove items until the new reply will fit with the current text.
    async def remove_history_until_reply_fits(self):
        logging.debug(f"Stripping conversation back until the reply fits.")
        while await self.is_reply_too_long() and len(await self.get_history()) > 0:
            logging.debug(f"Reply is too long. Removing oldest history item.")
            await self.remove_oldest_history_item()
        logging.debug(f"Cleanup is complete. Returning newly pruned history.")
        return await self.get_history()

    # Remove the oldest history item and return the new history.
    async def remove_oldest_history_item(self):
        conversation = await self.get_conversation_or_create()
        item = conversation.history.pop(0)
        logging.debug(f"Removing oldest history item: {item}")
        with app.app_context():
            Conversations.set_history(self.user_id, conversation.history)
            return Conversations.get_history(owner=self.user_id)

    # Look at the actual token counts of each item and compare against our limit.
    async def is_reply_too_long(self):
        reply_token_count = await self.get_reply_token_count()
        history_token_count = await self.get_history_token_count()
        logging.debug(f"Reply token count: {reply_token_count}")
        logging.debug(f"History token count: {history_token_count}")
        return reply_token_count + history_token_count > self.token_limit

    async def get_reply_token_count(self):
        """Token count of the pending reply, measured on its JSON encoding."""
        return self.tokenizer.get_token_count(json.dumps(self.reply))

    async def get_history_token_count(self):
        # Pad the value by 512 to accommodate for the metadata in the JSON we
        # can't really count right here. (Fix: the comment previously said 64
        # while the code added 512; 512 is what is actually applied.)
        return self.tokenizer.get_token_count(json.dumps(await self.get_history())) + 512

    # Format the history as a string for OpenAI.
    async def get_prompt(self):
        return json.dumps(await self.get_history())

    async def get_history(self):
        """Return the current (decoded) history list for this user."""
        conversation = await self.get_conversation_or_create()
        logging.debug(f"Conversation: {conversation}")
        return conversation.history

    async def is_history_empty(self):
        """True when the history is empty or still the pristine starter history."""
        history = await self.get_history()
        if len(history) == 0:
            return True
        if history == Conversations.get_new_history():
            return True
        return False

    async def add_user_reply(self, content: str):
        return await self.add_to_history("user", content)

    async def add_system_reply(self, content: str):
        return await self.add_to_history("system", content)

    async def add_assistant_reply(self, content: str):
        return await self.add_to_history("assistant", content)

    async def add_to_history(self, role: str, content: str):
        """Append a {role, content} message, pruning old history to fit.

        Raises ValueError when the reply cannot fit even after the history
        has been pruned (now reachable thanks to the validate_reply fix).
        """
        # Store the reply for processing
        self.reply = {"role": role, "content": content}
        if not await self.validate_reply():
            raise ValueError(f"I am sorry. It seems your reply would overrun the limits of reality and time. We are currently stuck at {self.token_limit} tokens, and your message used {await self.get_reply_token_count()} tokens. Please try again.")
        with app.app_context():
            conversation = await self.get_conversation_or_create()
            conversation.history.append(self.reply)
            self.conversations.set_history(self.user_id, conversation.history)
            return conversation.history

    def truncate_conversation_history(self, conversation_history, new_prompt, max_tokens=2048):
        """Drop the oldest entries until history plus new_prompt fit max_tokens.

        Mutates and returns conversation_history. Raises ValueError when the
        prompt alone exceeds the budget.
        """
        # Calculate tokens for new_prompt
        new_prompt_token_count = self.tokenizer.get_token_count(new_prompt)
        if new_prompt_token_count >= max_tokens:
            raise ValueError("The new prompt alone exceeds the maximum token limit.")

        # Calculate tokens for conversation_history
        conversation_history_token_counts = [len(self.tokenizer.tokenize(entry)) for entry in conversation_history]
        total_tokens = sum(conversation_history_token_counts) + new_prompt_token_count

        # Truncate conversation history if total tokens exceed max_tokens.
        # Fix: subtract the removed entry's count instead of re-summing the
        # whole list every iteration (was accidentally O(n^2)).
        while total_tokens > max_tokens:
            conversation_history.pop(0)  # Remove the oldest entry
            total_tokens -= conversation_history_token_counts.pop(0)

        return conversation_history
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,13 @@ | ||
from tiktoken import Tokenizer | ||
from tiktoken.tokenizer import Tokenizer as OpenAITokenizer | ||
import tiktoken | ||
|
||
# NOTE(review): legacy smoke-test from before TokenTester existed.
# `OpenAITokenizer` comes from `tiktoken.tokenizer`, an import path that no
# longer exists in current tiktoken releases, so these lines fail at import
# time — presumably they are the lines this commit deletes. TODO confirm
# against the rendered diff before keeping them.
text = "This is a sample text to check the token count."

tokenizer = OpenAITokenizer()
tokens = tokenizer.tokenize(text)

token_count = len(tokens)
print(f"Token count: {token_count}")
|
||
class TokenTester:
    """Thin wrapper around tiktoken for tokenizing and counting tokens.

    The encoding is resolved once per instance from the model name.
    """

    def __init__(self, engine: str = "gpt-3.5-turbo"):
        # Resolve the BPE encoding matching the given model name.
        self.tokenizer = tiktoken.encoding_for_model(engine)

    def tokenize(self, text):
        """Encode text into a list of token ids.

        Fix: a stray `return tokenizer.tokenize(text)` (old module-level
        tiktoken API) preceded this line and made the real encode call
        unreachable; it has been removed.
        """
        return self.tokenizer.encode(text, allowed_special='all')

    def get_token_count(self, text):
        """Return the number of tokens `text` encodes to."""
        tokens = self.tokenize(text)
        return len(tokens)
|
||
def truncate_conversation_history(conversation_history, new_prompt, max_tokens=2048, tokenizer=None):
    """Drop the oldest history entries until history + new_prompt fit max_tokens.

    Args:
        conversation_history: list of strings, oldest first; mutated in place.
        new_prompt: the prompt text that must always fit within the budget.
        max_tokens: combined token budget for history plus prompt.
        tokenizer: optional object exposing tokenize(text) -> sequence.
            Defaults to the legacy OpenAITokenizer for backward compatibility
            (NOTE(review): that class comes from a tiktoken import path that
            no longer exists — injecting a tokenizer avoids the dead default).

    Returns:
        The (possibly truncated) conversation_history list.

    Raises:
        ValueError: if new_prompt alone meets or exceeds max_tokens.
    """
    if tokenizer is None:
        # Backward-compatible default, unchanged from the original behavior.
        tokenizer = OpenAITokenizer()

    # Calculate tokens for new_prompt
    new_prompt_token_count = len(tokenizer.tokenize(new_prompt))
    if new_prompt_token_count >= max_tokens:
        raise ValueError("The new prompt alone exceeds the maximum token limit.")

    # Calculate tokens for conversation_history
    conversation_history_token_counts = [len(tokenizer.tokenize(entry)) for entry in conversation_history]
    total_tokens = sum(conversation_history_token_counts) + new_prompt_token_count

    # Truncate conversation history if total tokens exceed max_tokens.
    # Fix: subtract the removed entry's count instead of re-summing every
    # iteration (was accidentally O(n^2)).
    while total_tokens > max_tokens:
        conversation_history.pop(0)  # Remove the oldest entry
        total_tokens -= conversation_history_token_counts.pop(0)

    return conversation_history
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,27 @@ | ||
from discord.ext import commands | ||
from discord_tron_master.models.transformers import Transformers | ||
from discord_tron_master.models.conversation import Conversations | ||
from discord_tron_master.classes.text_replies import return_random as random_fact | ||
from discord_tron_master.classes.app_config import AppConfig | ||
import logging | ||
|
||
# Shared application config singleton for this cog module.
config = AppConfig()
# The Flask app (class attribute on AppConfig) supplies DB app contexts below.
app = AppConfig.flask
|
||
class User(commands.Cog):
    """Cog exposing user-level GPT conversation housekeeping commands."""

    def __init__(self, bot):
        self.bot = bot

    @commands.command(name="clear", help="Clear your GPT conversation history and start again.")
    async def clear_history(self, ctx):
        """Wipe the invoking user's stored history and report the outcome in-channel."""
        owner_id = ctx.author.id
        try:
            # DB work needs the Flask application context.
            with app.app_context():
                Conversations.clear_history_by_owner(owner=owner_id)
            success_message = (
                f"{ctx.author.mention} Well, well, well. It is like I don't even know you anymore. Did you know {random_fact()}?"
            )
            await ctx.send(success_message)
        except Exception as exc:
            # Best-effort: surface a friendly failure message rather than crashing the command.
            logging.error("Caught error when clearing user conversation history: " + str(exc))
            failure_message = (
                f"{ctx.author.mention} The smoothbrain geriatric that writes my codebase did not correctly implement that method. I am sorry. Trying again will only lead to tears."
            )
            await ctx.send(failure_message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from .base import db | ||
import datetime, json | ||
|
||
class Conversations(db.Model):
    """One GPT conversation per owner: a role prompt plus JSON-encoded history."""
    __tablename__ = 'conversations'
    id = db.Column(db.Integer, primary_key=True)
    owner = db.Column(db.BigInteger(), unique=False, nullable=False)
    # NOTE(review): unique=True here means two owners can never share the same
    # role string, which looks unintended — confirm before relying on it.
    role = db.Column(db.String(255), unique=True, nullable=False)
    history = db.Column(db.Text(), nullable=False, default='{}')
    created = db.Column(db.DateTime, nullable=False, default=db.func.now())
    updated = db.Column(db.DateTime, nullable=False, default=db.func.now())

    @staticmethod
    def get_all():
        """Return every conversation row."""
        return Conversations.query.all()

    @staticmethod
    def delete_all():
        """Delete every conversation row and commit."""
        for conversation in Conversations.get_all():
            db.session.delete(conversation)
        db.session.commit()

    @staticmethod
    def clear_history_by_owner(owner: int):
        """Reset an owner's history to the empty starting history."""
        conversation = Conversations.get_by_owner(owner)
        import logging
        if conversation is None:
            # Fix: nothing stored for this owner — clearing nothing is a no-op,
            # not an AttributeError.
            logging.debug(f"No conversation found for owner {owner}; nothing to clear.")
            return
        logging.debug(f"Conversation before clearing: {conversation.history}")
        conversation.history = json.dumps(Conversations.get_new_history())
        logging.debug(f"Cleared conversation. New history: {conversation.history}")
        # Update Flask DB timestamp
        conversation.updated = db.func.now()
        db.session.commit()

    @staticmethod
    def create(owner: int, role: str, history: dict = None):
        """Create a conversation for owner, or return the existing one.

        Raises ValueError when a new row is needed but no history was given.
        """
        existing_definition = Conversations.query.filter_by(owner=owner).first()
        if existing_definition is not None:
            return existing_definition
        if history is None:
            raise ValueError("History must be provided when creating a new conversation")
        conversation = Conversations(owner=owner, role=role, history=json.dumps(history))
        db.session.add(conversation)
        db.session.commit()
        return conversation

    @staticmethod
    def get_by_owner(owner: int):
        """Return the owner's conversation with history decoded, or None.

        Fix: previously dereferenced the query result unconditionally and
        raised AttributeError when no row existed, although callers
        (e.g. ChatML.get_conversation_or_create) explicitly test for None.
        """
        conversation = Conversations.query.filter_by(owner=owner).first()
        if conversation is None:
            return None
        # Decode the stored JSON once; guard in case it was already decoded
        # on this session-cached instance.
        if isinstance(conversation.history, str):
            conversation.history = json.loads(conversation.history)
        return conversation

    @staticmethod
    def set_history(owner: int, history: dict):
        """Replace the owner's history, bump the timestamp, and commit."""
        conversation = Conversations.get_by_owner(owner)
        conversation.history = json.dumps(history)
        conversation.updated = db.func.now()
        db.session.commit()
        return conversation

    @staticmethod
    def get_new_history(role: str = None) -> list:
        """Return a fresh history; seeded with a system message when role is given."""
        if role is None:
            return []
        # Fix: the seed message used the key "message", but ChatML (and the
        # OpenAI chat format) use "content" for message text.
        return [{"role": "system", "content": role}]

    @staticmethod
    def get_history(owner: int):
        """Return the owner's decoded history list, or None when absent."""
        conversation = Conversations.get_by_owner(owner)
        if conversation is None:
            # Callers treat a falsy result as "start a new history".
            return None
        # Unload it if it is a string:
        if isinstance(conversation.history, str):
            conversation.history = json.loads(conversation.history)
        return conversation.history

    @staticmethod
    def set_role(owner: int, role: str):
        """Update the owner's role prompt, bump the timestamp, and commit."""
        conversation = Conversations.get_by_owner(owner)
        conversation.role = role
        conversation.updated = db.func.now()
        db.session.commit()
        return conversation

    @staticmethod
    def get_role(owner: int):
        """Return the owner's stored role prompt."""
        conversation = Conversations.get_by_owner(owner)
        return conversation.role

    def to_dict(self):
        """Serializable snapshot of this row."""
        return {
            'owner': self.owner,
            'role': self.role,
            'history': self.history,
            'created': self.created,
            'updated': self.updated
        }

    def to_json(self):
        # Fix: dropped the redundant method-local `import json`; the module
        # already imports json at the top of the file.
        return json.dumps(self.to_dict())