Commit 6b42d4c: Switch to OpenAI API

ajbozarth committed Oct 10, 2024 (1 parent: b0a8cd8)

Showing 2 changed files with 43 additions and 28 deletions.
67 changes: 40 additions & 27 deletions qiskit_code_assistant_jupyterlab/handlers.py
@@ -16,6 +16,7 @@
 
 import json
 import os
+from datetime import datetime
 from pathlib import Path
 
 import requests
@@ -27,7 +28,7 @@
 runtime_configs = {
     "service_url": "http://localhost",
     "api_token": "",
-    "is_ollama": False,
+    "is_openai": False,
 }

@@ -64,19 +65,19 @@ def get_header():
         "Content-Type": "application/json",
         "X-Caller": "qiskit-code-assistant-jupyterlab",
     }
-    if not runtime_configs["is_ollama"]:
+    if not runtime_configs["is_openai"]:
         header["Authorization"] = f"Bearer {runtime_configs['api_token']}"
     return header
 
 
-def convert_ollama(model):
+def convert_openai(model):
     return {
-        "_id": model["model"],
+        "_id": model["id"],
         "disclaimer": {"accepted": "true"},
-        "display_name": model["name"],
+        "display_name": model["id"],
         "doc_link": "",
         "license": {"name": "", "link": ""},
-        "model_id": model["model"],
+        "model_id": model["id"],
         "prompt_type": 1,
         "token_limit": 255
     }
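
For reference, a minimal standalone sketch of what the new converter consumes and produces, assuming the OpenAI-style /v1/models response schema; the sample model id and "owned_by" value are hypothetical:

def convert_openai(model):
    return {
        "_id": model["id"],
        "disclaimer": {"accepted": "true"},
        "display_name": model["id"],
        "doc_link": "",
        "license": {"name": "", "link": ""},
        "model_id": model["id"],
        "prompt_type": 1,
        "token_limit": 255,
    }

# An OpenAI-style /v1/models response wraps entries in a "data" list, each
# carrying an "id"; Ollama's /api/tags used a "models" list with separate
# "model" and "name" keys, hence the changed lookups above.
sample_response = {
    "object": "list",
    "data": [{"id": "granite-8b-qiskit", "object": "model", "owned_by": "ibm"}],
}
models = [convert_openai(m) for m in sample_response["data"]]
print(models[0]["display_name"])  # -> granite-8b-qiskit
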
@@ -95,8 +96,9 @@ def post(self):
 
         try:
             r = requests.get(url_path_join(runtime_configs["service_url"]), headers=get_header())
-            # TODO: Replace with a check against the QCA service instead
-            runtime_configs["is_ollama"] = ("Ollama is running" in r.text)
+            runtime_configs["is_openai"] = (r.json()["name"] != "qiskit-code-assistant")
+        except requests.exceptions.JSONDecodeError:
+            runtime_configs["is_openai"] = True
         finally:
             self.finish(json.dumps({"url": runtime_configs["service_url"]}))

@@ -105,7 +107,7 @@ class TokenHandler(APIHandler):
     @tornado.web.authenticated
     def get(self):
         self.finish(json.dumps({"success": (runtime_configs["api_token"] != ""
-                                            or runtime_configs["is_ollama"])}))
+                                            or runtime_configs["is_openai"])}))
 
     @tornado.web.authenticated
     def post(self):
@@ -119,16 +121,16 @@ def post(self):
 class ModelsHandler(APIHandler):
     @tornado.web.authenticated
     def get(self):
-        if runtime_configs["is_ollama"]:
-            url = url_path_join(runtime_configs["service_url"], "api", "tags")
+        if runtime_configs["is_openai"]:
+            url = url_path_join(runtime_configs["service_url"], "v1", "models")
             models = []
             try:
                 r = requests.get(url, headers=get_header())
                 r.raise_for_status()
 
                 if r.ok:
-                    ollama_models = r.json()["models"]
-                    models = list(map(convert_ollama, ollama_models))
+                    data = r.json()["data"]
+                    models = list(map(convert_openai, data))
             except requests.exceptions.HTTPError as err:
                 self.set_status(err.response.status_code)
                 self.finish(json.dumps(err.response.json()))
@@ -150,9 +152,20 @@ def get(self):
 class ModelHandler(APIHandler):
     @tornado.web.authenticated
     def get(self, id):
-        if runtime_configs["is_ollama"]:
-            self.set_status(501, "Not implemented")
-            self.finish()
+        if runtime_configs["is_openai"]:
+            url = url_path_join(runtime_configs["service_url"], "v1", "models", id)
+            model = {}
+            try:
+                r = requests.get(url, headers=get_header())
+                r.raise_for_status()
+
+                if r.ok:
+                    model = convert_openai(r.json())
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(model))
         else:
             url = url_path_join(runtime_configs["service_url"], "model", id)

@@ -169,7 +182,7 @@ def get(self, id):
 class DisclaimerHandler(APIHandler):
     @tornado.web.authenticated
     def get(self, id):
-        if runtime_configs["is_ollama"]:
+        if runtime_configs["is_openai"]:
             self.set_status(501, "Not implemented")
             self.finish()
         else:
@@ -188,7 +201,7 @@ def get(self, id):
 class DisclaimerAcceptanceHandler(APIHandler):
     @tornado.web.authenticated
     def post(self, id):
-        if runtime_configs["is_ollama"]:
+        if runtime_configs["is_openai"]:
             self.set_status(501, "Not implemented")
             self.finish()
         else:
@@ -209,25 +222,25 @@ def post(self, id):
 class PromptHandler(APIHandler):
     @tornado.web.authenticated
     def post(self, id):
-        if runtime_configs["is_ollama"]:
-            url = url_path_join(runtime_configs["service_url"], "api", "generate")
+        if runtime_configs["is_openai"]:
+            url = url_path_join(runtime_configs["service_url"], "v1", "completions")
             result = {}
             try:
                 r = requests.post(url,
                                   headers=get_header(),
                                   json={
                                       "model": id,
-                                      "prompt": self.get_json_body()["input"],
-                                      "stream": False
+                                      "prompt": self.get_json_body()["input"]
                                   })
                 r.raise_for_status()
 
                 if r.ok:
-                    ollama_response = r.json()
+                    response = r.json()
                     result = {
-                        "results": [{"generated_text": ollama_response["response"]}],
-                        "prompt_id": ollama_response["created_at"],
-                        "created_at": ollama_response["created_at"]
+                        "results": list(map(lambda c: {"generated_text": c["text"]},
+                                            response["choices"])),
+                        "prompt_id": response["id"],
+                        "created_at": datetime.fromtimestamp(int(response["created"])).isoformat()
                     }
             except requests.exceptions.HTTPError as err:
                 self.set_status(err.response.status_code)
class PromptAcceptanceHandler(APIHandler):
@tornado.web.authenticated
def post(self, id):
if runtime_configs["is_ollama"]:
if runtime_configs["is_openai"]:
self.finish(json.dumps({"success": "true"}))
else:
url = url_path_join(runtime_configs["service_url"], "prompt", id, "acceptance")
4 changes: 3 additions & 1 deletion src/index.ts
@@ -68,7 +68,9 @@ const plugin: JupyterFrontEndPlugin<void> = {
 
     postServiceUrl(settings.composite['serviceUrl'] as string);
     settings.changed.connect(() =>
-      postServiceUrl(settings.composite['serviceUrl'] as string)
+      postServiceUrl(settings.composite['serviceUrl'] as string).then(() =>
+        refreshModelsList()
+      )
     );
 
     const provider = new QiskitCompletionProvider({ settings });