Skip to content

Commit

Permalink
fix(openai): back-port to support OpenAI TS SDK v3 (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
nirga authored Nov 8, 2023
1 parent 28238bd commit ed08f7b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 21 deletions.
1 change: 0 additions & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 84 additions & 20 deletions packages/instrumentation-openai/src/instrumentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,43 @@ export class OpenAIInstrumentation extends InstrumentationBase<any> {
super("@traceloop/instrumentation-openai", "0.0.17", config);
}

public manuallyInstrument(module: typeof openai.OpenAI) {
this._wrap(
module.Chat.Completions.prototype,
"create",
this.patchOpenAI("chat"),
);
this._wrap(
module.Completions.prototype,
"create",
this.patchOpenAI("completion"),
);
public manuallyInstrument(
module: typeof openai.OpenAI & { openLLMetryPatched?: boolean },
) {
if (module.openLLMetryPatched) {
return;
}

// Old version of OpenAI API (v3.1.0)
if ((module as any).OpenAIApi) {
this._wrap(
(module as any).OpenAIApi.prototype,
"createChatCompletion",
this.patchOpenAI("chat", "v3"),
);
this._wrap(
(module as any).OpenAIApi.prototype,
"createCompletion",
this.patchOpenAI("completion", "v3"),
);
} else {
this._wrap(
module.Chat.Completions.prototype,
"create",
this.patchOpenAI("chat"),
);
this._wrap(
module.Completions.prototype,
"create",
this.patchOpenAI("completion"),
);
}
}

protected init(): InstrumentationModuleDefinition<any> {
const module = new InstrumentationNodeModuleDefinition<any>(
"openai",
[">=4 <5"],
[">=3.1.0 <5"],
this.patch.bind(this),
this.unpatch.bind(this),
);
Expand All @@ -68,7 +88,23 @@ export class OpenAIInstrumentation extends InstrumentationBase<any> {
private patch(
moduleExports: typeof openai & { openLLMetryPatched?: boolean },
) {
if (!moduleExports.openLLMetryPatched) {
if (moduleExports.openLLMetryPatched) {
return moduleExports;
}

// Old version of OpenAI API (v3.1.0)
if ((moduleExports as any).OpenAIApi) {
this._wrap(
(moduleExports as any).OpenAIApi.prototype,
"createChatCompletion",
this.patchOpenAI("chat", "v3"),
);
this._wrap(
(moduleExports as any).OpenAIApi.prototype,
"createCompletion",
this.patchOpenAI("completion", "v3"),
);
} else {
moduleExports.openLLMetryPatched = true;
this._wrap(
moduleExports.OpenAI.Chat.Completions.prototype,
Expand All @@ -85,11 +121,26 @@ export class OpenAIInstrumentation extends InstrumentationBase<any> {
}

private unpatch(moduleExports: typeof openai): void {
this._unwrap(moduleExports.OpenAI.Chat.Completions.prototype, "create");
this._unwrap(moduleExports.OpenAI.Completions.prototype, "create");
// Old version of OpenAI API (v3.1.0)
if ((moduleExports as any).OpenAIApi) {
this._unwrap(
(moduleExports as any).OpenAIApi.prototype,
"createChatCompletion",
);
this._unwrap(
(moduleExports as any).OpenAIApi.prototype,
"createCompletion",
);
} else {
this._unwrap(moduleExports.OpenAI.Chat.Completions.prototype, "create");
this._unwrap(moduleExports.OpenAI.Completions.prototype, "create");
}
}

private patchOpenAI(type: "chat" | "completion") {
private patchOpenAI(
type: "chat" | "completion",
version: "v3" | "v4" = "v4",
) {
// eslint-disable-next-line @typescript-eslint/no-this-alias
const plugin = this;
// eslint-disable-next-line @typescript-eslint/ban-types
Expand Down Expand Up @@ -126,7 +177,7 @@ export class OpenAIInstrumentation extends InstrumentationBase<any> {
},
);

const wrappedPromise = wrapPromise(type, span, execPromise);
const wrappedPromise = wrapPromise(type, version, span, execPromise);

return context.bind(execContext, wrappedPromise as any);
};
Expand Down Expand Up @@ -206,16 +257,29 @@ export class OpenAIInstrumentation extends InstrumentationBase<any> {

function wrapPromise<T>(
type: "chat" | "completion",
version: "v3" | "v4",
span: Span,
promise: Promise<T>,
): Promise<T> {
return promise
.then((result) => {
return new Promise<T>((resolve) => {
if (type === "chat") {
endSpan({ type, span, result: result as ChatCompletion });
if (version === "v3") {
if (type === "chat") {
endSpan({
type,
span,
result: (result as any).data as ChatCompletion,
});
} else {
endSpan({ type, span, result: (result as any).data as Completion });
}
} else {
endSpan({ type, span, result: result as Completion });
if (type === "chat") {
endSpan({ type, span, result: result as ChatCompletion });
} else {
endSpan({ type, span, result: result as Completion });
}
}
resolve(result);
});
Expand Down

0 comments on commit ed08f7b

Please sign in to comment.