Skip to content

Commit

Permalink
Add feature to create a custom agent directly from the side panel wit…
Browse files Browse the repository at this point in the history
…h currently configured settings

- Also, when in not subscribed state, fallback to the default model when chatting with an agent
- With conversion, create a brand new agent from inside the chat view that can be managed separately
  • Loading branch information
sabaimran committed Feb 13, 2025
1 parent 5d6eca4 commit d0d30ac
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 12 deletions.
9 changes: 9 additions & 0 deletions src/interface/web/app/agents/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import { AppSidebar } from "../components/appSidebar/appSidebar";
import { Separator } from "@/components/ui/separator";
import { KhojLogoType } from "../components/logo/khojLogo";
import { DialogTitle } from "@radix-ui/react-dialog";
import Link from "next/link";

const agentsFetcher = () =>
window
Expand Down Expand Up @@ -343,6 +344,14 @@ export default function Agents() {
/>
<span className="font-bold">How it works</span> Use any of these
specialized personas to tune your conversation to your needs.
{
!isSubscribed && (
<span>
{" "}
<Link href="/settings" className="font-bold">Upgrade your plan</Link> to leverage custom models. You will fallback to the default model when chatting.
</span>
)
}
</AlertDescription>
</Alert>
<div className="pt-6 md:pt-8">
Expand Down
2 changes: 1 addition & 1 deletion src/interface/web/app/components/agentCard/agentCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ export function AgentCard(props: AgentCardProps) {
/>
</DialogContent>
) : (
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] rounded-lg">
<DialogContent className="whitespace-pre-line max-h-[80vh] max-w-[90vw] md:max-w-[50vw] rounded-lg">
<DialogHeader>
<div className="flex items-center">
{getIconFromIconName(props.data.icon, props.data.color)}
Expand Down
264 changes: 258 additions & 6 deletions src/interface/web/app/components/chatSidebar/chatSidebar.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client"

import { ArrowsDownUp, CaretCircleDown, CircleNotch, Sparkle } from "@phosphor-icons/react";
import { ArrowsDownUp, CaretCircleDown, CheckCircle, Circle, CircleNotch, PersonSimpleTaiChi, Sparkle } from "@phosphor-icons/react";

import { Button } from "@/components/ui/button";

Expand All @@ -14,13 +14,20 @@ import { mutate } from "swr";
import { Sheet, SheetContent } from "@/components/ui/sheet";
import { AgentData } from "../agentCard/agentCard";
import { useEffect, useState } from "react";
import { getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils";
import { getAvailableIcons, getIconForSlashCommand, getIconFromIconName } from "@/app/common/iconUtils";
import { Label } from "@/components/ui/label";
import { Checkbox } from "@/components/ui/checkbox";
import { Tooltip, TooltipTrigger } from "@/components/ui/tooltip";
import { TooltipContent } from "@radix-ui/react-tooltip";
import { useAuthenticatedData } from "@/app/common/auth";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { Dialog, DialogClose, DialogContent, DialogFooter, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
import { convertColorToTextClass, tailwindColors } from "@/app/common/colorUtils";
import { Input } from "@/components/ui/input";
import Link from "next/link";
import { motion } from "framer-motion";


interface ChatSideBarProps {
conversationId: string;
Expand Down Expand Up @@ -54,11 +61,245 @@ export function ChatSidebar({ ...props }: ChatSideBarProps) {
);
}

interface IAgentCreationProps {
customPrompt: string;
selectedModel: string;
inputTools: string[];
outputModes: string[];
}

interface AgentError {
detail: string;
}

function AgentCreationForm(props: IAgentCreationProps) {
const iconOptions = getAvailableIcons();
const colorOptions = tailwindColors;

const [isCreating, setIsCreating] = useState<boolean>(false);
const [customAgentName, setCustomAgentName] = useState<string | undefined>();
const [customAgentIcon, setCustomAgentIcon] = useState<string | undefined>();
const [customAgentColor, setCustomAgentColor] = useState<string | undefined>();

const [doneCreating, setDoneCreating] = useState<boolean>(false);
const [createdSlug, setCreatedSlug] = useState<string | undefined>();
const [isValid, setIsValid] = useState<boolean>(false);
const [error, setError] = useState<string | undefined>();

function createAgent() {
if (isCreating) {
return;
}

setIsCreating(true);

const data = {
name: customAgentName,
icon: customAgentIcon,
color: customAgentColor,
persona: props.customPrompt,
chat_model: props.selectedModel,
input_tools: props.inputTools,
output_modes: props.outputModes,
privacy_level: "private",
};

const createAgentUrl = `/api/agents`;

fetch(createAgentUrl, {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify(data)
})
.then((res) => res.json())
.then((data: AgentData | AgentError) => {
console.log("Success:", data);
if ('detail' in data) {
setError(`Error creating agent: ${data.detail}`);
setIsCreating(false);
return;
}
setDoneCreating(true);
setCreatedSlug(data.slug);
setIsCreating(false);
})
.catch((error) => {
console.error("Error:", error);
setError(`Error creating agent: ${error}`);
setIsCreating(false);
});
}

useEffect(() => {
if (customAgentName && customAgentIcon && customAgentColor) {
setIsValid(true);
} else {
setIsValid(false);
}
}, [customAgentName, customAgentIcon, customAgentColor]);

return (

<Dialog>
<DialogTrigger asChild>
<Button
className="p-1"
variant="ghost"
>
Create Agent
</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
{
doneCreating && createdSlug ? (
<DialogTitle>
Created {customAgentName}
</DialogTitle>
) : (
<DialogTitle>
Create a New Agent
</DialogTitle>
)
}
<DialogClose />
</DialogHeader>
<div className="py-4">
{
doneCreating && createdSlug ? (
<div className="flex flex-col items-center justify-center gap-4 py-8">
<motion.div
initial={{ scale: 0 }}
animate={{ scale: 1 }}
transition={{
type: "spring",
stiffness: 260,
damping: 20
}}
>
<CheckCircle
className="w-16 h-16 text-green-500"
weight="fill"
/>
</motion.div>
<motion.p
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.2 }}
className="text-center text-lg font-medium text-accent-foreground"
>
Created successfully!
</motion.p>
<motion.div
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.4 }}
>
<Link href={`/agents?agent=${createdSlug}`}>
<Button variant="secondary" className="mt-2">
Manage Agent
</Button>
</Link>
</motion.div>
</div>
) :
<div className="flex flex-col gap-4">
<div>
<Label htmlFor="agent_name">Name</Label>
<Input
id="agent_name"
className="w-full p-2 border mt-4 border-slate-500 rounded-lg"
disabled={isCreating}
value={customAgentName}
onChange={(e) => setCustomAgentName(e.target.value)}
/>
</div>
<div className="flex gap-4">
<div className="flex-1">
<Select onValueChange={setCustomAgentColor} defaultValue={customAgentColor}>
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
<SelectValue placeholder="Color" />
</SelectTrigger>
<SelectContent className="items-center space-y-1 inline-flex flex-col">
{colorOptions.map((colorOption) => (
<SelectItem key={colorOption} value={colorOption}>
<div className="flex items-center space-x-2">
<Circle
className={`w-6 h-6 mr-2 ${convertColorToTextClass(colorOption)}`}
weight="fill"
/>
{colorOption}
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex-1">
<Select onValueChange={setCustomAgentIcon} defaultValue={customAgentIcon}>
<SelectTrigger className="w-full dark:bg-muted" disabled={isCreating}>
<SelectValue placeholder="Icon" />
</SelectTrigger>
<SelectContent className="items-center space-y-1 inline-flex flex-col">
{iconOptions.map((iconOption) => (
<SelectItem key={iconOption} value={iconOption}>
<div className="flex items-center space-x-2">
{getIconFromIconName(
iconOption,
customAgentColor ?? "gray",
"w-6",
"h-6",
)}
{iconOption}
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
</div>
}
</div>
<DialogFooter>
{
error && (
<div className="text-red-500 text-sm">
{error}
</div>
)
}
{
!doneCreating && (
<Button
type="submit"
onClick={() => createAgent()}
disabled={isCreating || !isValid}
>
{
isCreating ?
<CircleNotch className="animate-spin" />
:
<PersonSimpleTaiChi />
}
Create
</Button>
)
}
<DialogClose />
</DialogFooter>
</DialogContent>
</Dialog >

)
}

function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
const [isEditable, setIsEditable] = useState<boolean>(false);
const { data: agentConfigurationOptions, error: agentConfigurationOptionsError } =
useSWR<AgentConfigurationOptions>("/api/agents/options", fetcher);
useSWR<AgentConfigurationOptions>("/api/agents/options", fetcher);

const { data: agentData, isLoading: agentDataLoading, error: agentDataError } = useSWR<AgentData>(`/api/agents/conversation?conversation_id=${props.conversationId}`, fetcher);
const {
Expand Down Expand Up @@ -211,9 +452,20 @@ function ChatSidebarInternal({ ...props }: ChatSideBarProps) {
</a>
</div>
) : (
<div className="flex items-center relative text-sm">
{getIconFromIconName("lightbulb", "orange")}
Chat Options
<div className="flex items-center relative text-sm justify-between">
<p>
Chat Options
</p>
{
isEditable && customPrompt && !isDefaultAgent && selectedModel && (
<AgentCreationForm
customPrompt={customPrompt}
selectedModel={selectedModel}
inputTools={inputTools ?? []}
outputModes={outputModes ?? []}
/>
)
}
</div>
)
}
Expand Down
6 changes: 4 additions & 2 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,8 +1356,10 @@ async def aget_conversation_starters(user: KhojUser, max_results=3):
return random.sample(all_questions, max_results)

@staticmethod
def get_valid_chat_model(user: KhojUser, conversation: Conversation):
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
agent: Agent = (
conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
)
if agent and agent.chat_model:
chat_model = conversation.agent.chat_model
else:
Expand Down
3 changes: 2 additions & 1 deletion src/khoj/routers/api_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def get_agent_by_conversation(
conversation_id: str,
) -> Response:
user: KhojUser = request.user.object if request.user.is_authenticated else None
is_subscribed = has_required_scope(request, ["premium"])
conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)

if not conversation:
Expand All @@ -132,7 +133,7 @@ async def get_agent_by_conversation(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name,
"chat_model": agent.chat_model.name if is_subscribed else None,
"has_files": has_files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
Expand Down
4 changes: 3 additions & 1 deletion src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
from starlette.authentication import has_required_scope, requires

from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
Expand Down Expand Up @@ -637,6 +637,7 @@ async def event_generator(q: str, images: list[str]):
chat_metadata: dict = {}
connection_alive = True
user: KhojUser = request.user.object
is_subscribed = has_required_scope(request, ["premium"])
event_delimiter = "␃🔚␗"
q = unquote(q)
train_of_thought = []
Expand Down Expand Up @@ -1251,6 +1252,7 @@ def collect_telemetry():
generated_mermaidjs_diagram,
program_execution_context,
generated_asset_results,
is_subscribed,
tracer,
)

Expand Down
Loading

0 comments on commit d0d30ac

Please sign in to comment.