Skip to content

Commit

Permalink
Merge pull request #841 from YanSte/cleanup
Browse files Browse the repository at this point in the history
cleanup code
  • Loading branch information
YanSte authored Feb 18, 2025
2 parents 99dc485 + 46e1865 commit 780d0b4
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 55 deletions.
17 changes: 0 additions & 17 deletions lightrag/kg/oracle_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ async def query(
await cursor.execute(sql, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(sql)
print(params)
raise
columns = [column[0].lower() for column in cursor.description]
if multirows:
Expand Down Expand Up @@ -172,8 +170,6 @@ async def execute(self, sql: str, data: Union[list, dict] = None):
await connection.commit()
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(sql)
print(data)
raise


Expand Down Expand Up @@ -349,9 +345,7 @@ async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"top_k": top_k,
"better_than_threshold": self.cosine_better_than_threshold,
}
# print(SQL)
results = await self.db.query(SQL, params=params, multirows=True)
# print("vector search result:",results)
return results

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
Expand Down Expand Up @@ -477,8 +471,6 @@ async def has_node(self, node_id: str) -> bool:
"""根据节点id检查节点是否存在"""
SQL = SQL_TEMPLATES["has_node"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
# print(self.db.workspace, node_id)
res = await self.db.query(SQL, params)
if res:
# print("Node exist!",res)
Expand All @@ -494,7 +486,6 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"source_node_id": source_node_id,
"target_node_id": target_node_id,
}
# print(SQL)
res = await self.db.query(SQL, params)
if res:
# print("Edge exist!",res)
Expand All @@ -506,33 +497,25 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
async def node_degree(self, node_id: str) -> int:
SQL = SQL_TEMPLATES["node_degree"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
res = await self.db.query(SQL, params)
if res:
# print("Node degree",res["degree"])
return res["degree"]
else:
# print("Edge not exist!")
return 0

async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""根据源和目标节点id获取边的度"""
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
# print("Edge degree",degree)
return degree

async def get_node(self, node_id: str) -> dict[str, str] | None:
"""根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(self.db.workspace, node_id)
# print(SQL)
res = await self.db.query(SQL, params)
if res:
# print("Get node!",self.db.workspace, node_id,res)
return res
else:
# print("Can't get node!",self.db.workspace, node_id)
return None

async def get_edge(
Expand Down
30 changes: 15 additions & 15 deletions lightrag/kg/postgres_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ async def query(
data = None
return data
except Exception as e:
logger.error(f"PostgreSQL database error: {e}")
print(sql)
print(params)
logger.error(
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
)
raise

async def execute(
Expand Down Expand Up @@ -167,9 +167,7 @@ async def execute(
else:
logger.error(f"Upsert error: {e}")
except Exception as e:
logger.error(f"PostgreSQL database error: {e.__class__} - {e}")
print(sql)
print(data)
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise

@staticmethod
Expand Down Expand Up @@ -266,9 +264,10 @@ async def filter_keys(self, keys: set[str]) -> set[str]:
new_keys = set([s for s in keys if s not in exist_keys])
return new_keys
except Exception as e:
logger.error(f"PostgreSQL database error: {e}")
print(sql)
print(params)
logger.error(
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
)
raise

################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
Expand Down Expand Up @@ -333,9 +332,9 @@ def _upsert_chunks(self, item: dict):
"content_vector": json.dumps(item["__vector__"].tolist()),
}
except Exception as e:
logger.error(f"Error to prepare upsert sql: {e}")
print(item)
raise e
logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}")
raise

return upsert_sql, data

def _upsert_entities(self, item: dict):
Expand Down Expand Up @@ -454,9 +453,10 @@ async def filter_keys(self, keys: set[str]) -> set[str]:
print(f"new_keys: {new_keys}")
return new_keys
except Exception as e:
logger.error(f"PostgreSQL database error: {e}")
print(sql)
print(params)
logger.error(
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
)
raise

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
Expand Down
11 changes: 3 additions & 8 deletions lightrag/kg/tidb_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ async def query(
try:
result = conn.execute(text(sql), params)
except Exception as e:
logger.error(f"Tidb database error: {e}")
print(sql)
print(params)
logger.error(f"Tidb database,\nsql:{sql},\nparams:{params},\nerror:{e}")
raise
if multirows:
rows = result.all()
Expand All @@ -103,9 +101,7 @@ async def execute(self, sql: str, data: list | dict = None):
else:
conn.execute(text(sql), parameters=data)
except Exception as e:
logger.error(f"TiDB database error: {e}")
print(sql)
print(data)
logger.error(f"Tidb database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise


Expand Down Expand Up @@ -145,8 +141,7 @@ async def filter_keys(self, keys: set[str]) -> set[str]:
try:
await self.db.query(SQL)
except Exception as e:
logger.error(f"Tidb database error: {e}")
print(SQL)
logger.error(f"Tidb database,\nsql:{SQL},\nkeys:{keys},\nerror:{e}")
res = await self.db.query(SQL, multirows=True)
if res:
exist_keys = [key["id"] for key in res]
Expand Down
30 changes: 15 additions & 15 deletions lightrag/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from lightrag.api import __api_version__

import numpy as np
from typing import Union
from typing import Any, Union


class InvalidResponseError(Exception):
Expand All @@ -94,13 +94,13 @@ class InvalidResponseError(Exception):
),
)
async def openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=None,
base_url=None,
api_key=None,
**kwargs,
model: str,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
base_url: str | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> str:
if history_messages is None:
history_messages = []
Expand All @@ -125,7 +125,7 @@ async def openai_complete_if_cache(
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
messages = []
messages: list[dict[str, Any]] = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
Expand All @@ -147,18 +147,18 @@ async def openai_complete_if_cache(
model=model, messages=messages, **kwargs
)
except APIConnectionError as e:
logger.error(f"OpenAI API Connection Error: {str(e)}")
logger.error(f"OpenAI API Connection Error: {e}")
raise
except RateLimitError as e:
logger.error(f"OpenAI API Rate Limit Error: {str(e)}")
logger.error(f"OpenAI API Rate Limit Error: {e}")
raise
except APITimeoutError as e:
logger.error(f"OpenAI API Timeout Error: {str(e)}")
logger.error(f"OpenAI API Timeout Error: {e}")
raise
except Exception as e:
logger.error(f"OpenAI API Call Failed: {str(e)}")
logger.error(f"Model: {model}")
logger.error(f"Request parameters: {kwargs}")
logger.error(
f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}"
)
raise

if hasattr(response, "__aiter__"):
Expand Down

0 comments on commit 780d0b4

Please sign in to comment.