Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vecDB integration example #42

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
abhijeethp marked this conversation as resolved.
Show resolved Hide resolved
File renamed without changes.
26 changes: 26 additions & 0 deletions examples/destination/python/vector_db/csv_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pandas as pd
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
import zstandard as zstd
from io import StringIO
from logger import log


class CSVReaderAESZSTD:

def read_csv(self, file_path, aes_key, null_string, timestamp_columns):
with open(file_path, 'rb') as encrypted_file:
iv = encrypted_file.read(16)
encrypted_data = encrypted_file.read()

cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = unpad(cipher.decrypt(encrypted_data), AES.block_size)

decompressor = zstd.ZstdDecompressor()

with decompressor.stream_reader(decrypted_data) as reader:
decompressed_data = reader.read()

data_str = decompressed_data.decode('utf-8')
df = pd.read_csv(StringIO(data_str), na_values=null_string, parse_dates=timestamp_columns)
return df
41 changes: 41 additions & 0 deletions examples/destination/python/vector_db/destination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from datetime import datetime
from abc import ABC, abstractmethod

from models.collection import Collection, Row
from typing import Optional, Any, List

from sdk.common_pb2 import ConfigurationFormResponse


class VectorDestination(ABC):
"""Interface for vector destinations"""

@abstractmethod
def configuration_form(self) -> ConfigurationFormResponse:
"""configuration_form"""

@abstractmethod
def test(self, name: str, configuration: dict[str, str]) -> Optional[str]:
"""test"""

@abstractmethod
def create_collection_if_not_exists(self, configuration: dict[str, Any], collection: Collection) -> None:
"""create_collection"""

@abstractmethod
def upsert_rows(self, configuration: dict[str, Any], collection: Collection, rows: List[Row]) -> None:
"""upsert_rows"""

@abstractmethod
def delete_rows(self, configuration: dict[str, Any], collection: Collection, ids: List[str]) -> None:
"""delete_rows"""

@abstractmethod
def truncate(self, configuration: dict[str, Any], collection: Collection, synced_column: str, delete_before: datetime) -> None:
"""delete_rows"""

# Not Ideal but no clear winner ¯\_(ツ)_/¯
def get_collection_name(self, schema_name: str, table_name: str) -> str:
schema_name = schema_name.replace("_","-")
table_name = table_name.replace("_", "-")
return f"{schema_name}-{table_name}"
86 changes: 86 additions & 0 deletions examples/destination/python/vector_db/destinations/weaviate_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import uuid
import weaviate
from weaviate.auth import AuthApiKey
from weaviate.classes.query import Filter

from sdk.common_pb2 import ConfigurationFormResponse, FormField, TextField, ConfigurationTest
from destination import VectorDestination


class WeaviateDestination(VectorDestination):
def get_collection_name(self, schema_name, table_name):
return f"{schema_name}_{table_name}"

def configuration_form(self):
fields = [
FormField(name="url", label="Weaviate Cluster URL", required=True, text_field=TextField.PlainText),
FormField(name="api_key", label="Weaviate API Key", required=True, text_field=TextField.Password)
]
tests = [ConfigurationTest(name="connection_test", label="Connecting to Weaviate Cluster")]
return ConfigurationFormResponse(fields=fields, tests=tests)

def _get_client(self, configuration):
return weaviate.connect_to_wcs(
cluster_url=configuration["url"],
auth_credentials=AuthApiKey(configuration["api_key"])
)

def test(self, name, config):
if name != "connection_test":
raise ValueError(name)

client = self._get_client(config)

client.connect()
client.close()

def create_collection_if_not_exists(self, config, collection):
client = self._get_client(config)

if not client.collections.exists(collection.name):
print(f"Collection {collection.name} does not exist! Creating!")
client.collections.create(name=collection.name)

client.close()

def upsert_rows(self, config, collection, rows):
client = self._get_client(config)
c = client.collections.get(collection.name)

for row in rows:
_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(row.id))

# TODO: Get these swap column names from config
# TODO: swap column name in framework if other vec dbs have same issue.
id_swap_column = "_fvt_swp_id"
vector_swap_column = "_fvt_swp_vector"

if "id" in row.payload:
row.payload[id_swap_column] = row.payload.pop("id")
if "vector" in row.payload:
row.payload[vector_swap_column] = row.payload.pop("vector")

if c.data.exists(_uuid):
c.data.replace(uuid=_uuid, properties=row.payload, vector=row.vector)
else:
c.data.insert(uuid=_uuid, properties=row.payload, vector=row.vector)

client.close()

def delete_rows(self, config, collection, ids):
client = self._get_client(config)
c = client.collections.get(collection.name)

for id in ids:
_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(id))
c.data.delete_by_id(uuid=_uuid)

client.close()

def truncate(self, config, collection, synced_column, delete_before):
client = self._get_client(config)
c = client.collections.get(collection.name)

filter = Filter.by_property(synced_column).less_than(delete_before)
c.data.delete_many(where=filter)

31 changes: 31 additions & 0 deletions examples/destination/python/vector_db/embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod

from typing import Optional, Any, List, Tuple

from sdk.common_pb2 import ConfigurationFormResponse
from models.collection import Metrics


class Embedder(ABC):
"""Interface for embedders"""

@abstractmethod
def details(self)-> Tuple[str, str]:
"""details -> [id, name]"""


@abstractmethod
def configuration_form(self) -> ConfigurationFormResponse:
"""configuration_form"""

@abstractmethod
def metrics(self, configuration: dict[str, str])-> Metrics:
"""metrics"""

@abstractmethod
def test(self, name: str, configuration: dict[str, str]) -> Optional[str]:
"""test"""

@abstractmethod
def embed(self, configuration: dict[str, str], texts: List[str]) -> List[List[float]]:
"""embed"""
46 changes: 46 additions & 0 deletions examples/destination/python/vector_db/embedders/open_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from embedder import Embedder
from sdk.common_pb2 import ConfigurationFormResponse, FormField, TextField, DropdownField, ConfigurationTest
from langchain_openai import OpenAIEmbeddings
from models.collection import Metrics, Distance

MODELS = {
"text-embedding-ada-002": Metrics(
distance=Distance.COSINE,
dimensions=1536
)
}


class OpenAIEmbedder(Embedder):

def details(self):
return "open_ai", "OpenAI"

def configuration_form(self):
models = DropdownField(dropdown_field=list(MODELS.keys()))
fields = [
FormField(name="api_key", label="OpenAI API Key", required=True, text_field=TextField.Password),
FormField(name="embedding_model", label="OpenAI Embedding Model", required=True, dropdown_field=models),
]
tests = [ConfigurationTest(name="embedding_test", label="Checking OpenAI Embedding Generation")]
return ConfigurationFormResponse(fields=fields, tests=tests)

def metrics(self, config) -> Metrics:
return MODELS[config["embedding_model"]]

def _get_embedding(self, configuration):
api_key = configuration["api_key"]
model = configuration["embedding_model"]

return OpenAIEmbeddings(api_key=api_key, model=model)

def test(self, name, configuration):
if name != "embedding_test":
raise ValueError(f'Unknown test : {name}')

embedding = self._get_embedding(configuration)
embedding.embed_query("foo-bar-biz")

def embed(self, configuration, texts):
embedding = self._get_embedding(configuration)
return embedding.embed_documents(texts)
10 changes: 10 additions & 0 deletions examples/destination/python/vector_db/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import json


def log(msg: str):
m = {
"level": "INFO",
"message": msg,
"message-origin": "sdk_destination"
}
print(json.dumps(m), flush=True)
27 changes: 27 additions & 0 deletions examples/destination/python/vector_db/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import grpc
from concurrent import futures
from sdk.destination_sdk_pb2_grpc import add_DestinationServicer_to_server
from destination import VectorDestination
from service import VectorDestinationServicer
import sys


def serve(vec_dest: VectorDestination):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
add_DestinationServicer_to_server(VectorDestinationServicer(vec_dest), server)

if len(sys.argv) == 3 and sys.argv[1] == '--port':
port = int(sys.argv[2])
else:
port = 50052

server.add_insecure_port(f'[::]:{port}')
print(f"Running GRPC Server on {port}")
server.start()
server.wait_for_termination()


from destinations.weaviate_ import WeaviateDestination

if __name__ == '__main__':
serve(WeaviateDestination())
31 changes: 31 additions & 0 deletions examples/destination/python/vector_db/models/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import List, Dict, Any
from enum import Enum


class Distance(Enum):
COSINE = 1
DOT = 2
EUCLIDIAN = 3


@dataclass
class Metrics:
distance: Distance
dimensions: int


@dataclass
class Collection:
name: str
metrics: Metrics


@dataclass
class Row:
id: str
vector: List[float]
content: str
payload: Dict[str, Any]


9 changes: 9 additions & 0 deletions examples/destination/python/vector_db/requirements.txt
abhijeethp marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pandas
weaviate-client
langchain
langchain-community
langchain-openai
pycrypto
pycryptodome
zstandard
grpcio
Loading
Loading