Skip to content

Commit

Permalink
Add AsBuilder extensions for IChatClient and IEmbeddingGenerator (#5652)
Browse files Browse the repository at this point in the history
* Add ToBuilder extensions for IChatClient and IEmbeddingGenerator

Enables a fluent style of construction of a pipeline from a client/generator, and not having to specify the generic type parameters for the embedding generator builder.

* Rename ToBuilder to AsBuilder
  • Loading branch information
stephentoub authored Nov 18, 2024
1 parent c468947 commit 930af05
Show file tree
Hide file tree
Showing 22 changed files with 115 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;

/// <summary>Provides extension methods for working with <see cref="IChatClient"/> in the context of <see cref="ChatClientBuilder"/>.</summary>
public static class ChatClientBuilderChatClientExtensions
{
/// <summary>Creates a new <see cref="ChatClientBuilder"/> using <paramref name="innerClient"/> as its inner client.</summary>
/// <param name="innerClient">The client to use as the inner client.</param>
/// <returns>The new <see cref="ChatClientBuilder"/> instance.</returns>
/// <remarks>
/// This method is equivalent to using the <see cref="ChatClientBuilder"/> constructor directly,
/// specifying <paramref name="innerClient"/> as the inner client.
/// </remarks>
public static ChatClientBuilder AsBuilder(this IChatClient innerClient)
{
_ = Throw.IfNull(innerClient);

return new ChatClientBuilder(innerClient);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;

/// <summary>Provides extension methods for working with <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>
/// in the context of <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/>.</summary>
public static class EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions
{
/// <summary>
/// Creates a new <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> using
/// <paramref name="innerGenerator"/> as its inner generator.
/// </summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The type of embeddings to generate.</typeparam>
/// <param name="innerGenerator">The generator to use as the inner generator.</param>
/// <returns>The new <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> instance.</returns>
/// <remarks>
/// This method is equivalent to using the <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/>
/// constructor directly, specifying <paramref name="innerGenerator"/> as the inner generator.
/// </remarks>
public static EmbeddingGeneratorBuilder<TInput, TEmbedding> AsBuilder<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> innerGenerator)
where TEmbedding : Embedding
{
_ = Throw.IfNull(innerGenerator);

return new EmbeddingGeneratorBuilder<TInput, TEmbedding>(innerGenerator);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public void GetService_SuccessfullyReturnsUnderlyingClient()

Assert.Same(client, chatClient.GetService<ChatCompletionsClient>());

using IChatClient pipeline = new ChatClientBuilder(chatClient)
using IChatClient pipeline = chatClient
.AsBuilder()
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public void GetService_SuccessfullyReturnsUnderlyingClient()
Assert.Same(embeddingGenerator, embeddingGenerator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Same(client, embeddingGenerator.GetService<EmbeddingsClient>());

using IEmbeddingGenerator<string, Embedding<float>> pipeline = new EmbeddingGeneratorBuilder<string, Embedding<float>>(embeddingGenerator)
using IEmbeddingGenerator<string, Embedding<float>> pipeline = embeddingGenerator
.AsBuilder()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ public virtual async Task Caching_BeforeFunctionInvocation_AvoidsExtraCalls()
}, "GetTemperature");

// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
using var chatClient = CreateChatClient()!
.AsBuilder()
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseFunctionInvocation()
Expand Down Expand Up @@ -415,7 +416,8 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange
}, "GetTemperature");

// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
using var chatClient = CreateChatClient()!
.AsBuilder()
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseFunctionInvocation()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
Expand Down Expand Up @@ -454,7 +456,8 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedA
}, "GetTemperature");

// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
using var chatClient = CreateChatClient()!
.AsBuilder()
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseFunctionInvocation()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
Expand Down Expand Up @@ -573,7 +576,7 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
.AddInMemoryExporter(activities)
.Build();

var chatClient = new ChatClientBuilder(CreateChatClient()!)
var chatClient = CreateChatClient()!.AsBuilder()
.UseOpenTelemetry(sourceName: sourceName)
.Build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ public virtual async Task Caching_SameOutputsForSameInput()
{
SkipIfNotEnabled();

using var generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(CreateEmbeddingGenerator()!)
using var generator = CreateEmbeddingGenerator()!
.AsBuilder()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseCallCounting()
.Build();
Expand Down Expand Up @@ -110,7 +111,8 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
.AddInMemoryExporter(activities)
.Build();

var embeddingGenerator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(CreateEmbeddingGenerator()!)
var embeddingGenerator = CreateEmbeddingGenerator()!
.AsBuilder()
.UseOpenTelemetry(sourceName: sourceName)
.Build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public async Task Reduction_LimitsMessagesBasedOnTokenLimit()
}
};

using var client = new ChatClientBuilder(innerClient)
using var client = innerClient
.AsBuilder()
.UseChatReducer(new TokenCountingChatReducer(_gpt4oTokenizer, 40))
.Build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public async Task PromptBasedFunctionCalling_NoArgs()
{
SkipIfNotEnabled();

using var chatClient = new ChatClientBuilder(CreateChatClient()!)
using var chatClient = CreateChatClient()!
.AsBuilder()
.UseFunctionInvocation()
.UsePromptBasedFunctionCalling()
.Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient))
Expand All @@ -61,7 +62,8 @@ public async Task PromptBasedFunctionCalling_WithArgs()
{
SkipIfNotEnabled();

using var chatClient = new ChatClientBuilder(CreateChatClient()!)
using var chatClient = CreateChatClient()!
.AsBuilder()
.UseFunctionInvocation()
.UsePromptBasedFunctionCalling()
.Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public void GetService_SuccessfullyReturnsUnderlyingClient()
Assert.Same(client, client.GetService<OllamaChatClient>());
Assert.Same(client, client.GetService<IChatClient>());

using IChatClient pipeline = new ChatClientBuilder(client)
using IChatClient pipeline = client
.AsBuilder()
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public void GetService_SuccessfullyReturnsUnderlyingClient()
Assert.Same(generator, generator.GetService<OllamaEmbeddingGenerator>());
Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());

using IEmbeddingGenerator<string, Embedding<float>> pipeline = new EmbeddingGeneratorBuilder<string, Embedding<float>>(generator)
using IEmbeddingGenerator<string, Embedding<float>> pipeline = generator
.AsBuilder()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ public void GetService_OpenAIClient_SuccessfullyReturnsUnderlyingClient()

Assert.NotNull(chatClient.GetService<ChatClient>());

using IChatClient pipeline = new ChatClientBuilder(chatClient)
using IChatClient pipeline = chatClient
.AsBuilder()
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
Expand All @@ -119,7 +120,8 @@ public void GetService_ChatClient_SuccessfullyReturnsUnderlyingClient()
Assert.Same(chatClient, chatClient.GetService<IChatClient>());
Assert.Same(openAIClient, chatClient.GetService<ChatClient>());

using IChatClient pipeline = new ChatClientBuilder(chatClient)
using IChatClient pipeline = chatClient
.AsBuilder()
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ public void GetService_OpenAIClient_SuccessfullyReturnsUnderlyingClient()

Assert.NotNull(embeddingGenerator.GetService<EmbeddingClient>());

using IEmbeddingGenerator<string, Embedding<float>> pipeline = new EmbeddingGeneratorBuilder<string, Embedding<float>>(embeddingGenerator)
using IEmbeddingGenerator<string, Embedding<float>> pipeline = embeddingGenerator
.AsBuilder()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Build();
Expand All @@ -100,7 +101,8 @@ public void GetService_EmbeddingClient_SuccessfullyReturnsUnderlyingClient()
Assert.Same(embeddingGenerator, embeddingGenerator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Same(openAIClient, embeddingGenerator.GetService<EmbeddingClient>());

using IEmbeddingGenerator<string, Embedding<float>> pipeline = new EmbeddingGeneratorBuilder<string, Embedding<float>>(embeddingGenerator)
using IEmbeddingGenerator<string, Embedding<float>> pipeline = embeddingGenerator
.AsBuilder()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public void BuildsPipelineInOrderAdded()
public void DoesNotAcceptNullInnerService()
{
Assert.Throws<ArgumentNullException>("innerClient", () => new ChatClientBuilder((IChatClient)null!));
Assert.Throws<ArgumentNullException>("innerClient", () => ((IChatClient)null!).AsBuilder());
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public void ConfigureOptionsChatClient_InvalidArgs_Throws()
public void ConfigureOptions_InvalidArgs_Throws()
{
using var innerClient = new TestChatClient();
var builder = new ChatClientBuilder(innerClient);
var builder = innerClient.AsBuilder();
Assert.Throws<ArgumentNullException>("configure", () => builder.ConfigureOptions(null!));
}

Expand Down Expand Up @@ -55,7 +55,8 @@ public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullP
},
};

using var client = new ChatClientBuilder(innerClient)
using var client = innerClient
.AsBuilder()
.ConfigureOptions(options =>
{
Assert.NotSame(providedOptions, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ public async Task CanResolveIDistributedCacheFromDI()
new(ChatRole.Assistant, [new TextContent("Hey")])]));
}
};
using var outer = new ChatClientBuilder(testClient)
using var outer = testClient
.AsBuilder()
.UseDistributedCache(configure: options =>
{
options.JsonSerializerOptions = TestJsonSerializerContext.Default.Options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ public async Task RejectsMultipleChoicesAsync()
}
};

IChatClient service = new ChatClientBuilder(innerClient).UseFunctionInvocation().Build();
IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build();

List<ChatMessage> chat = [new ChatMessage(ChatRole.User, "hello")];
var ex = await Assert.ThrowsAsync<InvalidOperationException>(
Expand Down Expand Up @@ -415,7 +415,7 @@ private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
}
};

IChatClient service = configurePipeline(new ChatClientBuilder(innerClient)).Build();
IChatClient service = configurePipeline(innerClient.AsBuilder()).Build();

var result = await service.CompleteAsync(chat, options, cts.Token);
chat.Add(result.Message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level)
},
};

using IChatClient client = new ChatClientBuilder(innerClient)
using IChatClient client = innerClient
.AsBuilder()
.UseLogging()
.Build(services);

Expand Down Expand Up @@ -86,7 +87,8 @@ static async IAsyncEnumerable<StreamingChatCompletionUpdate> GetUpdatesAsync()
yield return new StreamingChatCompletionUpdate { Role = ChatRole.Assistant, Text = "whale" };
}

using IChatClient client = new ChatClientBuilder(innerClient)
using IChatClient client = innerClient
.AsBuilder()
.UseLogging(logger)
.Build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ async static IAsyncEnumerable<StreamingChatCompletionUpdate> CallbackAsync(
};
}

var chatClient = new ChatClientBuilder(innerClient)
var chatClient = innerClient
.AsBuilder()
.UseOpenTelemetry(loggerFactory, sourceName, configure: instance =>
{
instance.EnableSensitiveData = enableSensitiveData;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public void ConfigureOptionsEmbeddingGenerator_InvalidArgs_Throws()
public void ConfigureOptions_InvalidArgs_Throws()
{
using var innerGenerator = new TestEmbeddingGenerator();
var builder = new EmbeddingGeneratorBuilder<string, Embedding<float>>(innerGenerator);
var builder = innerGenerator.AsBuilder();
Assert.Throws<ArgumentNullException>("configure", () => builder.ConfigureOptions(null!));
}

Expand All @@ -45,7 +45,8 @@ public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullP
}
};

using var generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(innerGenerator)
using var generator = innerGenerator
.AsBuilder()
.ConfigureOptions(options =>
{
Assert.NotSame(providedOptions, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ public async Task CanResolveIDistributedCacheFromDI()
return Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([_expectedEmbedding]);
},
};
using var outer = new EmbeddingGeneratorBuilder<string, Embedding<float>>(testGenerator)
using var outer = testGenerator
.AsBuilder()
.UseDistributedCache(configure: instance =>
{
instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void BuildsPipelineInOrderAdded()
{
// Arrange
using var expectedInnerGenerator = new TestEmbeddingGenerator();
var builder = new EmbeddingGeneratorBuilder<string, Embedding<float>>(expectedInnerGenerator);
var builder = expectedInnerGenerator.AsBuilder();

builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("First", next));
builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("Second", next));
Expand All @@ -58,6 +58,7 @@ public void BuildsPipelineInOrderAdded()
public void DoesNotAcceptNullInnerService()
{
Assert.Throws<ArgumentNullException>("innerGenerator", () => new EmbeddingGeneratorBuilder<string, Embedding<float>>((IEmbeddingGenerator<string, Embedding<float>>)null!));
Assert.Throws<ArgumentNullException>("innerGenerator", () => ((IEmbeddingGenerator<string, Embedding<float>>)null!).AsBuilder());
}

[Fact]
Expand All @@ -71,7 +72,7 @@ public void DoesNotAcceptNullFactories()
public void DoesNotAllowFactoriesToReturnNull()
{
using var innerGenerator = new TestEmbeddingGenerator();
var builder = new EmbeddingGeneratorBuilder<string, Embedding<float>>(innerGenerator);
var builder = innerGenerator.AsBuilder();
builder.Use(_ => null!);
var ex = Assert.Throws<InvalidOperationException>(() => builder.Build());
Assert.Contains("entry at index 0", ex.Message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level)
},
};

using IEmbeddingGenerator<string, Embedding<float>> generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(innerGenerator)
using IEmbeddingGenerator<string, Embedding<float>> generator = innerGenerator
.AsBuilder()
.UseLogging()
.Build(services);

Expand Down

0 comments on commit 930af05

Please sign in to comment.