Skip to content

Commit 8bc8adb

Browse files
feat: [vertexai] add fluent API in GenerativeModel (#10585)
PiperOrigin-RevId: 617585215 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent bedcddf commit 8bc8adb

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,41 @@ public Builder setTools(List<Tool> tools) {
181181
}
182182
}
183183

184+
/**
185+
* Creates a copy of the current model with updated GenerationConfig.
186+
*
187+
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be
188+
* used in the new model.
189+
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
190+
*/
191+
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
192+
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
193+
}
194+
195+
/**
196+
* Creates a copy of the current model with updated safetySettings.
197+
*
198+
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will
199+
* be used in the new model.
200+
* @return a new {@link GenerativeModel} instance with the specified safetySettings.
201+
*/
202+
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
203+
return new GenerativeModel(
204+
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
205+
}
206+
207+
/**
208+
* Creates a copy of the current model with updated tools.
209+
*
210+
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
211+
* the new model.
212+
* @return a new {@link GenerativeModel} instance with the specified tools.
213+
*/
214+
public GenerativeModel withTools(List<Tool> tools) {
215+
return new GenerativeModel(
216+
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
217+
}
218+
184219
/**
185220
* Counts tokens in a text message.
186221
*

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,34 @@ public void testGenerateContentwithDefaultTools() throws Exception {
416416
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
417417
}
418418

419+
@Test
420+
public void testGenerateContentwithFluentApi() throws Exception {
421+
model = new GenerativeModel(MODEL_NAME, vertexAi);
422+
423+
Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
424+
field.setAccessible(true);
425+
field.set(vertexAi, mockPredictionServiceClient);
426+
427+
when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
428+
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
429+
.thenReturn(mockGenerateContentResponse);
430+
431+
GenerateContentResponse unused =
432+
model
433+
.withGenerationConfig(GENERATION_CONFIG)
434+
.withSafetySettings(safetySettings)
435+
.withTools(tools)
436+
.generateContent(TEXT);
437+
438+
ArgumentCaptor<GenerateContentRequest> request =
439+
ArgumentCaptor.forClass(GenerateContentRequest.class);
440+
verify(mockUnaryCallable).call(request.capture());
441+
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
442+
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
443+
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
444+
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
445+
}
446+
419447
@Test
420448
public void testGenerateContentStreamwithText() throws Exception {
421449
model = new GenerativeModel(MODEL_NAME, vertexAi);
@@ -569,4 +597,34 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {
569597
verify(mockServerStreamCallable).call(request.capture());
570598
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
571599
}
600+
601+
@Test
602+
public void testGenerateContentStreamwithFluentApi() throws Exception {
603+
model = new GenerativeModel(MODEL_NAME, vertexAi);
604+
605+
Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
606+
field.setAccessible(true);
607+
field.set(vertexAi, mockPredictionServiceClient);
608+
609+
when(mockPredictionServiceClient.streamGenerateContentCallable())
610+
.thenReturn(mockServerStreamCallable);
611+
when(mockServerStreamCallable.call(any(GenerateContentRequest.class)))
612+
.thenReturn(mockServerStream);
613+
when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator);
614+
615+
ResponseStream unused =
616+
model
617+
.withGenerationConfig(GENERATION_CONFIG)
618+
.withSafetySettings(safetySettings)
619+
.withTools(tools)
620+
.generateContentStream(TEXT);
621+
622+
ArgumentCaptor<GenerateContentRequest> request =
623+
ArgumentCaptor.forClass(GenerateContentRequest.class);
624+
verify(mockServerStreamCallable).call(request.capture());
625+
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
626+
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
627+
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
628+
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
629+
}
572630
}

0 commit comments

Comments
 (0)