Skip to content

Commit

Permalink
Update Tests, add Llama Guard Test
Browse files Browse the repository at this point in the history
  • Loading branch information
MatKuhr committed Jan 17, 2025
1 parent d60505e commit 4908718
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
import com.github.tomakehurst.wiremock.stubbing.Scenario;
import com.sap.ai.sdk.orchestration.model.ChatMessage;
import com.sap.ai.sdk.orchestration.model.ChatMessagesInner;
import com.sap.ai.sdk.orchestration.model.CompletionPostRequest;
import com.sap.ai.sdk.orchestration.model.CompletionPostResponse;
import com.sap.ai.sdk.orchestration.model.DPIEntities;
Expand All @@ -47,9 +46,12 @@
import com.sap.ai.sdk.orchestration.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.model.LLMModuleResult;
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
import com.sap.ai.sdk.orchestration.model.LlamaGuard38b;
import com.sap.ai.sdk.orchestration.model.LlamaGuard38bFilterConfig;
import com.sap.ai.sdk.orchestration.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.model.MultiChatMessage;
import com.sap.ai.sdk.orchestration.model.OrchestrationConfig;
import com.sap.ai.sdk.orchestration.model.SingleChatMessage;
import com.sap.ai.sdk.orchestration.model.Template;
import com.sap.ai.sdk.orchestration.model.TextContent;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
Expand Down Expand Up @@ -147,7 +149,7 @@ void testGrounding() {
assertThat(groundingData.get("grounding_result").toString())
.startsWith("Joule is the AI copilot that truly understands your business.");
assertThat(result.getModuleResults().getGrounding().getMessage()).isEqualTo("grounding result");
assertThat(((ChatMessage) result.getModuleResults().getTemplating().get(0)).getContent())
assertThat(((SingleChatMessage) result.getModuleResults().getTemplating().get(0)).getContent())
.startsWith(
"What does Joule do? Use the following information as additional context: Joule is the AI copilot that truly understands your business.");
assertThat(llmChoice.getMessage().getContent())
Expand Down Expand Up @@ -258,14 +260,22 @@ void filteringLoose() throws IOException {
.withBodyFile("filteringLooseResponse.json")
.withHeader("Content-Type", "application/json")));

final var filter =
final var azureFilter =
new AzureContentFilter()
.hate(ALLOW_SAFE_LOW_MEDIUM)
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
.sexual(ALLOW_SAFE_LOW_MEDIUM)
.violence(ALLOW_SAFE_LOW_MEDIUM);

client.chatCompletion(prompt, config.withInputFiltering(filter).withOutputFiltering(filter));
ContentFilter llamaFilter =
() ->
LlamaGuard38bFilterConfig.create()
.type(LlamaGuard38bFilterConfig.TypeEnum.LLAMA_GUARD_3_8B)
.config(LlamaGuard38b.create().selfHarm(true));

client.chatCompletion(
prompt,
config.withInputFiltering(azureFilter, llamaFilter).withOutputFiltering(azureFilter));
// the result is asserted in the verify step below

// verify that null fields are absent from the sent request
Expand Down Expand Up @@ -603,7 +613,7 @@ void streamChatCompletionDeltas() throws IOException {
final var templating = deltaList.get(0).getModuleResults().getTemplating();
assertThat(templating).hasSize(1);

final var templateItem = (ChatMessage) templating.get(0);
final var templateItem = (SingleChatMessage) templating.get(0);
assertThat(templateItem.getRole()).isEqualTo("user");
assertThat(templateItem.getContent())
.isEqualTo("Hello world! Why is this phrase so famous?");
Expand Down Expand Up @@ -755,10 +765,9 @@ void testOrchestrationChatResponseWithMultiChatMessage() {
module.addDeserializer(
LLMModuleResult.class,
PolymorphicFallbackDeserializer.fromJsonSubTypes(LLMModuleResult.class));
module.setMixInAnnotation(ChatMessagesInner.class, JacksonMixins.NoneTypeInfoMixin.class);
module.setMixInAnnotation(ChatMessage.class, JacksonMixins.NoneTypeInfoMixin.class);
module.addDeserializer(
ChatMessagesInner.class,
PolymorphicFallbackDeserializer.fromJsonSubTypes(ChatMessagesInner.class));
ChatMessage.class, PolymorphicFallbackDeserializer.fromJsonSubTypes(ChatMessage.class));

var orchestrationChatResponse =
new OrchestrationChatResponse(
Expand Down
10 changes: 9 additions & 1 deletion orchestration/src/test/resources/filteringLooseRequest.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
"role": "user",
"content": "Hello World! Why is this phrase so famous?"
}
]
],
"defaults" : { },
"tools" : [ ]
},
"filtering_module_config": {
"input": {
Expand All @@ -32,6 +34,12 @@
"Sexual": 4,
"Violence": 4
}
},
{
"type": "llama_guard_3_8b",
"config": {
"self_harm": true
}
}
]
},
Expand Down
4 changes: 3 additions & 1 deletion orchestration/src/test/resources/maskingRequest.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
"role": "user",
"content": "Hello World! Why is this phrase so famous?"
}
]
],
"defaults" : { },
"tools" : [ ]
},
"masking_module_config": {
"masking_providers": [
Expand Down
4 changes: 3 additions & 1 deletion orchestration/src/test/resources/messagesHistoryRequest.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
"role": "user",
"content": "What is the typical food there?"
}
]
],
"defaults" : { },
"tools" : [ ]
}
},
"stream": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
}
]
}
]
],
"defaults" : { },
"tools" : [ ]
}
},
"stream": false
Expand Down
4 changes: 3 additions & 1 deletion orchestration/src/test/resources/templatingRequest.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
"role": "user",
"content": "{{?input}}"
}
]
],
"defaults" : { },
"tools" : [ ]
},
"llm_module_config": {
"model_name": "gpt-35-turbo-16k",
Expand Down

0 comments on commit 4908718

Please sign in to comment.