chatbot.py
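"""A retrieval-augmented chatbot built on llama-index.

Wires together a HuggingFace LLM, a BGE embedding model (wrapped via
LangChain), and a PostgreSQL/pgvector vector store, then answers queries
through a retriever + response-synthesizer query engine.
"""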
import torch  # referenced by the optional CUDA model_kwargs in _setup_model
from transformers import BitsAndBytesConfig
from sqlalchemy import make_url
from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from llama_index.core import PromptTemplate, Settings, VectorStoreIndex, get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.langchain import LangchainEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.vector_stores.postgres import PGVectorStore
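
# A minimal 4-bit quantization sketch, assuming a CUDA setup. This config is
# not part of the original file; it is added so the commented-out
# `quantization_config` reference inside _setup_model has a definition to
# pair with. Uncomment both together.
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
# )
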
class AIChatBot:
    def __init__(self, args):
        print("Initializing AI Chatbot")
        # Initialize model configuration and database connection here
        self.model_name = args.model_name
        self.embedding_model_name = args.embedding_model_name
        self.database_connection_string = args.database_connection_string
        self.database_name = args.database_name
        self.table_name = args.table_name
        self.system_prompt = args.system_prompt
        self.context_window = args.context_window
        self.max_new_tokens = args.max_new_tokens
        self.device = args.device
        self.chunk_size = args.chunk_size
        self.chunk_overlap = args.chunk_overlap
        self.top_k_index_to_return = args.top_k_index_to_return
    def _setup_model(self):
        print("Setting up model")
        # This will wrap the default prompts that are internal to llama-index
        query_wrapper_prompt = PromptTemplate("<|USER|>{query_str}<|ASSISTANT|>")
        # Create the LLM using the HuggingFaceLLM class
        llm = HuggingFaceLLM(
            context_window=self.context_window,
            max_new_tokens=self.max_new_tokens,
            system_prompt=self.system_prompt,
            query_wrapper_prompt=query_wrapper_prompt,
            tokenizer_name=self.model_name,
            model_name=self.model_name,
            device_map=self.device,
            generate_kwargs={"temperature": 0.2, "top_k": 5, "top_p": 0.95, "do_sample": True},
            # Uncomment this if using CUDA to reduce memory usage
            # (requires the quantization_config sketched at module level)
            # model_kwargs={
            #     # "torch_dtype": torch.float16,
            #     "quantization_config": quantization_config,
            # },
        )
        # Create the embedding model using the HuggingFaceBgeEmbeddings class
        embed_model = LangchainEmbedding(
            HuggingFaceBgeEmbeddings(model_name=self.embedding_model_name)
        )
        # Get the embedding dimension of the model by embedding a dummy input
        embed_dim = len(embed_model.get_text_embedding("Hello world"))  # e.g. 1024 for bge-large models
        self.embed_dim = embed_dim
        self.llm = llm
        self.embed_model = embed_model
    def _apply_settings(self):
        print("Applying settings")
        Settings.llm = self.llm
        Settings.embed_model = self.embed_model
        Settings.chunk_size = self.chunk_size
        Settings.chunk_overlap = self.chunk_overlap
    def _get_index_from_database(self):
        print("Getting index from database")
        # Creates a URL object from the connection string
        url = make_url(self.database_connection_string)
        # Create the vector store
        vector_store = PGVectorStore.from_params(
            database=self.database_name,
            host=url.host,
            password=url.password,
            port=url.port,
            user=url.username,
            table_name=self.table_name,
            embed_dim=self.embed_dim,
        )
        # Load the index from the vector store of the database
        index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
        self.index = index
    def _setup_engine(self):
        print("Setting up engine")
        # Create the retriever that manages the index and the number of results to return
        retriever = VectorIndexRetriever(
            index=self.index,
            similarity_top_k=self.top_k_index_to_return,
        )
        # Create the response synthesizer that will be used to synthesize the response
        response_synthesizer = get_response_synthesizer(
            response_mode="simple_summarize",
        )
        # Create the query engine that will be used to query the retriever and synthesize the response
        engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
        )
        self.engine = engine
    def build_bot(self):
        self._setup_model()
        self._apply_settings()
        self._get_index_from_database()
        self._setup_engine()
        print("Bot built successfully")

    def process_query(self, query_text):
        response = self.engine.query(query_text)
        return response
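

# A minimal usage sketch, assuming the file is run directly. The argument
# names mirror the attributes read in __init__; every default value below is
# an illustrative assumption, not a value from the original file.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="RAG chatbot over a pgvector store")
    parser.add_argument("--model_name", default="HuggingFaceH4/zephyr-7b-beta")  # assumed default
    parser.add_argument("--embedding_model_name", default="BAAI/bge-large-en-v1.5")  # assumed default
    parser.add_argument("--database_connection_string", default="postgresql://postgres:password@localhost:5432")  # assumed default
    parser.add_argument("--database_name", default="vector_db")  # assumed default
    parser.add_argument("--table_name", default="documents")  # assumed default
    parser.add_argument("--system_prompt", default="You are a helpful assistant. Answer using the provided context.")
    parser.add_argument("--context_window", type=int, default=4096)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--chunk_size", type=int, default=1024)
    parser.add_argument("--chunk_overlap", type=int, default=32)
    parser.add_argument("--top_k_index_to_return", type=int, default=3)
    args = parser.parse_args()

    bot = AIChatBot(args)
    bot.build_bot()
    print(bot.process_query("What does this knowledge base cover?"))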