Skip to content

Commit

Permalink
Prompt template override in BaseProvider (#309)
Browse files Browse the repository at this point in the history
* Uses PromptTemplate class

* Code reorgnization for custom prompts

* Refactors prompt_template into BaseProvider

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adds prompt_template function for getting templates

* Adds update_prompt_template function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Makes prompt templates an instance var

* Renames getter function

* Uses underscore syntax for local variable

* Sets prompt_templates in ctor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Documents prompt templates

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
JasonWeill and pre-commit-ci[bot] authored Aug 3, 2023
1 parent f744b9a commit 51fd0d9
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 31 deletions.
14 changes: 14 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,20 @@ A function that computes the lowest common multiples of two integers, and
a function that runs 5 test cases of the lowest common multiple function
```

### Prompt templates

Each provider can define **prompt templates** for each supported format. A prompt
template guides the language model to produce output in a particular
format. The default prompt templates are a
[Python dictionary mapping formats to templates](/~https://github.com/jupyterlab/jupyter-ai/blob/57a758fa5cdd5a87da5519987895aa688b3766a8/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py#L138-L166).
Developers who write subclasses of `BaseProvider` can override templates per
output format, per model, and based on the prompt being submitted, by
implementing their own
[`get_prompt_template` function](/~https://github.com/jupyterlab/jupyter-ai/blob/57a758fa5cdd5a87da5519987895aa688b3766a8/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py#L186-L195).
Each prompt template includes the string `{prompt}`, which is replaced with
the user-provided prompt when the user runs a magic command.


### Clearing the OpenAI chat history

With the `openai-chat` provider *only*, you can run a cell magic command using the `-r` or
Expand Down
49 changes: 18 additions & 31 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ def _repr_mimebundle_(self, include=None, exclude=None):

NA_MESSAGE = '<abbr title="Not applicable">N/A</abbr>'

MARKDOWN_PROMPT_TEMPLATE = "{prompt}\n\nProduce output in markdown format only."

PROVIDER_NO_MODELS = "This provider does not define a list of models."

CANNOT_DETERMINE_MODEL_TEXT = """Cannot determine model provider from model ID '{0}'.
Expand All @@ -93,17 +91,6 @@ def _repr_mimebundle_(self, include=None, exclude=None):
To see a list of models you can use, run `%ai list`"""


PROMPT_TEMPLATES_BY_FORMAT = {
"code": "{prompt}\n\nProduce output as source code only, with no text or explanation before or after it.",
"html": "{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.",
"image": "{prompt}\n\nProduce output as an image only, with no text before or after it.",
"markdown": MARKDOWN_PROMPT_TEMPLATE,
"md": MARKDOWN_PROMPT_TEMPLATE,
"math": "{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.",
"json": "{prompt}\n\nProduce output in JSON format only, with nothing before or after it.",
"text": "{prompt}", # No customization
}

AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"}


Expand Down Expand Up @@ -465,24 +452,6 @@ def handle_list(self, args: ListArgs):
)

def run_ai_cell(self, args: CellArgs, prompt: str):
# Apply a prompt template.
prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt=prompt)

# interpolate user namespace into prompt
ip = get_ipython()
prompt = prompt.format_map(FormatDict(ip.user_ns))

# Determine provider and local model IDs
# If this is a custom chain, send the message to the custom chain.
if args.model_id in self.custom_model_registry and isinstance(
self.custom_model_registry[args.model_id], LLMChain
):
return self.display_output(
self.custom_model_registry[args.model_id].run(prompt),
args.format,
{"jupyter_ai": {"custom_chain_id": args.model_id}},
)

provider_id, local_model_id = self._decompose_model_id(args.model_id)
Provider = self._get_provider(provider_id)
if Provider is None:
Expand All @@ -500,6 +469,17 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
self.transcript_openai = []
return

# Determine provider and local model IDs
# If this is a custom chain, send the message to the custom chain.
if args.model_id in self.custom_model_registry and isinstance(
self.custom_model_registry[args.model_id], LLMChain
):
return self.display_output(
self.custom_model_registry[args.model_id].run(prompt),
args.format,
{"jupyter_ai": {"custom_chain_id": args.model_id}},
)

# validate presence of authn credentials
auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy:
Expand Down Expand Up @@ -541,6 +521,13 @@ def run_ai_cell(self, args: CellArgs, prompt: str):

provider = Provider(**provider_params)

# Apply a prompt template.
prompt = provider.get_prompt_template(args.format).format(prompt=prompt)

# interpolate user namespace into prompt
ip = get_ipython()
prompt = prompt.format_map(FormatDict(ip.user_ns))

# generate output from model via provider
result = provider.generate([prompt])
output = result.generations[0][0].text
Expand Down
52 changes: 52 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union

from jsonpath_ng import parse
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import (
AI21,
Expand Down Expand Up @@ -117,6 +118,10 @@ class Config:
# instance attrs
#
model_id: str
prompt_templates: Dict[str, PromptTemplate]
"""Prompt templates for each output type. Can be overridden with
`update_prompt_template`. The function `prompt_template`, in the base class,
refers to this."""

def __init__(self, *args, **kwargs):
try:
Expand All @@ -130,6 +135,36 @@ def __init__(self, *args, **kwargs):
if self.__class__.model_id_key != "model_id":
model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]

model_kwargs["prompt_templates"] = {
"code": PromptTemplate.from_template(
"{prompt}\n\nProduce output as source code only, "
"with no text or explanation before or after it."
),
"html": PromptTemplate.from_template(
"{prompt}\n\nProduce output in HTML format only, "
"with no markup before or afterward."
),
"image": PromptTemplate.from_template(
"{prompt}\n\nProduce output as an image only, "
"with no text before or after it."
),
"markdown": PromptTemplate.from_template(
"{prompt}\n\nProduce output in markdown format only."
),
"md": PromptTemplate.from_template(
"{prompt}\n\nProduce output in markdown format only."
),
"math": PromptTemplate.from_template(
"{prompt}\n\nProduce output in LaTeX format only, "
"with $$ at the beginning and end."
),
"json": PromptTemplate.from_template(
"{prompt}\n\nProduce output in JSON format only, "
"with nothing before or after it."
),
"text": PromptTemplate.from_template("{prompt}"), # No customization
}

super().__init__(*args, **kwargs, **model_kwargs)

async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
Expand All @@ -142,6 +177,23 @@ async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
_call_with_args = functools.partial(self._call, *args, **kwargs)
return await loop.run_in_executor(executor, _call_with_args)

def update_prompt_template(self, format: str, template: str):
"""
Changes the class-level prompt template for a given format.
"""
self.prompt_templates[format] = PromptTemplate.from_template(template)

def get_prompt_template(self, format) -> PromptTemplate:
"""
Produce a prompt template suitable for use with a particular model, to
produce output in a desired format.
"""

if format in self.prompt_templates:
return self.prompt_templates[format]
else:
return self.prompt_templates["text"] # Default to plain format


class AI21Provider(BaseProvider, AI21):
id = "ai21"
Expand Down

0 comments on commit 51fd0d9

Please sign in to comment.