# utils.py
import asyncio
import os
import pickle
import re
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List

import pandas as pd
from langchain.schema import Document
from langchain_community.utilities import SQLDatabase
from langsmith import Client
from nltk.stem import WordNetLemmatizer

def connect_db(type='url', **kwargs):
    """
    Connects to a SQL database either through a URI string or an SQLAlchemy engine.

    Parameters:
    - type (str): Specifies the connection type ('url' for URI string, 'engine' for SQLAlchemy engine).
    - **kwargs: Additional keyword arguments specific to the connection type.

    Returns:
    - SQLDatabase instance connected as specified.

    Raises:
    - ValueError: If 'type' is neither 'url' nor 'engine'.
    """
    if type == 'url':
        return SQLDatabase.from_uri(**kwargs)
    elif type == 'engine':
        return SQLDatabase(**kwargs)
    raise ValueError(f"Unknown connection type: {type!r}")

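# Example usage (a sketch, not from the original repo; 'example.db' is a
# hypothetical local SQLite file, and 'database_uri' is the keyword argument
# expected by SQLDatabase.from_uri):
#
#     db = connect_db(type='url', database_uri='sqlite:///example.db')
#     print(db.get_usable_table_names())
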
def format_redshift_uri():
    """
    Formats a Redshift database URI using environment variables for user and password.

    Returns:
    - A formatted URI string for connecting to a Redshift database.
    """
    return f"redshift+psycopg2://{os.environ['redshift_user']}:{os.environ['redshift_pass']}@redshift-cluster-comp0087-demo.cvliubs5oipw.eu-west-2.redshift.amazonaws.com:5439/comp0087"

async def execute_sql_async(db: SQLDatabase, query: str, executor: ThreadPoolExecutor) -> str:
    """
    Asynchronously executes a SQL query using a thread pool executor.

    Parameters:
    - db (SQLDatabase): The database connection object.
    - query (str): The SQL query to execute.
    - executor (ThreadPoolExecutor): The executor for asynchronous execution.

    Returns:
    - The query result as a string, or 'Error' if execution fails.
    """
    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(executor, db.run, query)
    except Exception:
        result = 'Error'
    return result

async def execute_all_queries(db: SQLDatabase, question_df: pd.DataFrame, input_col_name: str, output_col_name: str) -> pd.DataFrame:
    """
    Asynchronously executes multiple SQL queries stored in a DataFrame column, updating the DataFrame with the results.

    Parameters:
    - db (SQLDatabase): The database connection object.
    - question_df (pd.DataFrame): DataFrame containing the queries to execute.
    - input_col_name (str): Name of the column containing the SQL queries.
    - output_col_name (str): Name of the column where results will be stored.

    Returns:
    - The updated DataFrame with query results stored in 'output_col_name'.
    """
    with ThreadPoolExecutor() as executor:
        tasks = [
            execute_sql_async(db, query, executor)
            for query in question_df[input_col_name]
        ]
        results = await asyncio.gather(*tasks)
    question_df[output_col_name] = results
    return question_df

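# Example usage (a sketch; assumes a connected db and a DataFrame holding one
# SQL query per row):
#
#     df = pd.DataFrame({'query': ['SELECT 1', 'SELECT 2']})
#     df = asyncio.run(execute_all_queries(db, df, 'query', 'result'))
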
def load_data_langsmith(df: pd.DataFrame, dataset_name: str, description: str, answer: bool = False) -> None:
    """
    Loads data into LangSmith from a DataFrame, optionally including answers.

    Parameters:
    - df (pd.DataFrame): The DataFrame containing the data to load.
    - dataset_name (str): The name for the new dataset in LangSmith.
    - description (str): A description for the new dataset.
    - answer (bool, optional): Whether to include answers in the data. Defaults to False.

    The function creates a new dataset in LangSmith and populates it with examples
    from the DataFrame. If 'answer' is True, each example includes a question, query,
    and answer. Otherwise, examples include only questions and queries.
    """
    client = Client()
    dataset = client.create_dataset(
        dataset_name=dataset_name,
        description=description
    )
    if not answer:
        client.create_examples(
            inputs=[{"question": q} for q in df.question.values],
            outputs=[{"query": q} for q in df["query"].values],
            dataset_id=dataset.id,
        )
    else:
        client.create_examples(
            inputs=[{"question": q} for q in df.question.values],
            outputs=[{"query": q, "answer": a} for q, a in zip(df["query"].values, df["answer"].values)],
            dataset_id=dataset.id,
        )

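# Example usage (a sketch; assumes the LANGSMITH_API_KEY environment variable
# is set and df has 'question' and 'query' columns; the dataset name and
# description here are hypothetical):
#
#     load_data_langsmith(df, 'text2sql-dev', 'Dev questions with gold queries')
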
def load_testpkl_df(path: str) -> pd.DataFrame:
    """
    Loads a DataFrame from a pickle file, specifically for test datasets.

    Parameters:
    - path (str): The file path to the pickle file containing the dataset.

    Returns:
    - pd.DataFrame: A DataFrame containing the test dataset, with columns for
      the question text, database ID, and database path.

    This function reads a pickle file into a DataFrame, selecting specific columns
    ('text', 'db_id', 'db_path') and renaming 'text' to 'question'.
    """
    with open(path, 'rb') as file:
        data = pickle.load(file)
    df = pd.DataFrame(data)[['text', 'db_id', 'db_path']].rename(columns={'text': 'question'})
    return df

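# Example usage (a sketch; 'test.pkl' is a hypothetical pickle of records with
# 'text', 'db_id' and 'db_path' fields):
#
#     test_df = load_testpkl_df('test.pkl')
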
def sql_parse(text: str) -> Dict[str, List[str]]:
    """
    Parses a SQL schema dump into a mapping from table names to column names.

    Parameters:
    - text (str): The schema text, containing CREATE TABLE statements with
      tab-indented column definitions.

    Returns:
    - Dict[str, List[str]]: A mapping of each table name to its column names,
      excluding PRIMARY KEY and FOREIGN KEY constraint lines.
    """
    sql_split = text.split("\n")
    schema = {}
    table_name = None
    for line in sql_split:
        table_match = re.search(r'CREATE\s+TABLE\s+("?)(\w+)("?)\s*\(', line)
        if table_match and table_match.group(2) not in schema:
            table_name = table_match.group(2)
            schema[table_name] = []
        column_match = re.search(r"^\t(.*?),", line)
        if column_match and table_name is not None:
            term = column_match[0].split()
            # Skip PRIMARY KEY and FOREIGN KEY constraint lines
            if not (("PRIMARY" in term and "KEY" in term) or ("FOREIGN" in term and "KEY" in term)):
                col_name = re.sub(r'"', '', term[0])
                schema[table_name].append(col_name)
    return schema

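# Example usage (a sketch; the schema string mirrors the tab-indented,
# comma-terminated format the column regex expects):
#
#     ddl = 'CREATE TABLE "users" (\n\t"id" INTEGER,\n\t"name" TEXT,\n\tPRIMARY KEY ("id"),\n)'
#     sql_parse(ddl)  # -> {'users': ['id', 'name']}
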
def contains_consecutive(word, column_name):
    """Returns True if any four consecutive characters of 'word' appear in 'column_name'."""
    for i in range(len(word) - 3):
        substring = word[i:i + 4]
        if substring in column_name:
            return True
    return False

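# Example (a quick illustration):
#
#     contains_consecutive('salaries', 'salary')  # True ('sala' is shared)
#     contains_consecutive('age', 'salary')       # False (fewer than 4 characters)
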
def mask_question(question, db_schema, common_schema_toks=None):
    """Mask the question, replacing column and table names with '[MASK]'.

    Args:
        question: Question to be masked.
        db_schema: Database schema, either a dict mapping table names to column
            lists or a flat list of schema tokens.
        common_schema_toks (optional): Extra tokens to mask as well. Defaults to None.

    Returns:
        str: The masked question.
    """
    words = question.split()
    # Create a WordNetLemmatizer object and lemmatize each word in the question
    lemmatizer = WordNetLemmatizer()
    words_lemmatize = [lemmatizer.lemmatize(word) for word in words]
    # Flatten the schema into a single list of tokens (column names plus table names)
    if isinstance(db_schema, dict):
        raw_toks = [tok for table, cols in db_schema.items() for tok in list(cols) + [table]]
    else:
        raw_toks = list(db_schema)
    # Split column names on '_'; for example, 'employee_id' becomes 'employee' and 'id'
    schema_related_toks = [part for tok in raw_toks for part in (tok.split('_') if '_' in tok else [tok])]
    if common_schema_toks:
        schema_related_toks += common_schema_toks
    # Lemmatise the schema-related tokens as well
    schema_related_toks = [lemmatizer.lemmatize(tok.lower()) for tok in schema_related_toks]
    for tok in schema_related_toks:
        for i, word in enumerate(words_lemmatize):
            if contains_consecutive(word.lower(), tok) or word in schema_related_toks:
                words[i] = '[MASK]'
    return ' '.join(words)

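# Example usage (a sketch; assumes the NLTK 'wordnet' corpus has been
# downloaded, e.g. via nltk.download('wordnet')):
#
#     schema = {'employees': ['employee_id', 'salary']}
#     mask_question('What is the average salary of employees?', schema)
#     # -> 'What is the average [MASK] of [MASK]'
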
def get_schema_related_toks(row):
    """
    Extracts and aggregates schema-related tokens from a row.

    Parameters:
    - row: A row from a DataFrame, expected to contain 'tables' and 'columns'.

    Returns:
    - List[str]: A list of schema-related tokens including table names and columns.
    """
    schema_toks = [table[0] for table in row['tables']]
    schema_toks += row['columns']
    return schema_toks

def mask_question_df(row, common_schema_related_toks):
    """
    Masks schema-related tokens in a question based on common schema tokens.

    Parameters:
    - row: A row from a DataFrame, expected to contain 'question' and 'schema_toks'.
    - common_schema_related_toks: List of common schema-related tokens to be masked in the question.

    Returns:
    - The masked question text.
    """
    question, db_schema = row['question'], row['schema_toks']
    return mask_question(question, db_schema, common_schema_related_toks)

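# Example usage (a sketch; assumes a DataFrame whose 'tables' column holds
# (name, ...) sequences and whose 'columns' column holds lists of column names;
# the common-token list here is hypothetical):
#
#     df['schema_toks'] = df.apply(get_schema_related_toks, axis=1)
#     df['masked'] = df.apply(mask_question_df, axis=1, args=(['table', 'column'],))
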
def search_in_document(df: pd.DataFrame, keywords: List[str]) -> str:
    """
    Searches for keywords in document titles and aggregates matching document content.

    Parameters:
    - df (pd.DataFrame): DataFrame containing document 'title' and 'content'.
    - keywords (List[str]): List of keywords to search for in document titles.

    Returns:
    - str: Aggregated content of documents whose titles contain any of the keywords.
    """
    res = []
    for keyword in keywords:
        mask = df['title'].apply(lambda x: keyword in str(x).lower())
        res.extend(df[mask].content.to_list())
    return '\n'.join(res)

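# Example usage (a sketch with hypothetical documentation rows; keywords should
# be lowercase, since titles are lowercased before matching):
#
#     docs = pd.DataFrame({'title': ['JOIN syntax', 'GROUP BY'],
#                          'content': ['Use JOIN to ...', 'Use GROUP BY to ...']})
#     search_in_document(docs, ['join'])  # -> 'Use JOIN to ...'
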
def summarise_keywords_from_result(retrieved_docs: List[Document]) -> List[str]:
    """
    Summarises keywords from retrieved documents, filtering out common SQL keywords.

    Parameters:
    - retrieved_docs (List[Document]): List of retrieved Document objects with metadata.

    Returns:
    - List[str]: List of relevant, summarised keywords from the documents' metadata.
    """
    keywords = [item.metadata['sql_keywords'] for item in retrieved_docs]
    keywords = [item.split(',') for item in keywords]
    keywords = [words for lst in keywords for words in lst]
    # Keep only keywords that appear in at least two retrieved documents
    word_counts = dict(Counter(keywords))
    keywords_selected = [word for word, count in word_counts.items() if count >= 2]
    ignore_keywords = ['select', 'from']
    return [keyword for keyword in keywords_selected if keyword not in ignore_keywords]

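# Example usage (a sketch; the 'sql_keywords' metadata key follows the
# retriever set-up this module assumes):
#
#     docs = [Document(page_content='', metadata={'sql_keywords': 'select,join,where'}),
#             Document(page_content='', metadata={'sql_keywords': 'select,join'})]
#     summarise_keywords_from_result(docs)  # -> ['join']
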
def summarise_qa_from_result(retrieved_docs: List[Document]) -> str:
    """
    Summarises questions and answers from retrieved documents.

    Parameters:
    - retrieved_docs (List[Document]): List of retrieved Document objects with metadata.

    Returns:
    - str: Formatted summary of questions and their corresponding answers from documents.
    """
    res = ['Question: ' + item.metadata['question'] + '\n' + 'Answer: ' + item.metadata['query'] for item in retrieved_docs]
    return '\n'.join(res)

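# Example usage (a sketch with hypothetical metadata):
#
#     docs = [Document(page_content='', metadata={'question': 'How many users?',
#                                                 'query': 'SELECT COUNT(*) FROM users'})]
#     print(summarise_qa_from_result(docs))
#     # Question: How many users?
#     # Answer: SELECT COUNT(*) FROM users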