Skip to content

Commit f024111

Browse files
BREAKING_CHANGE: [vertexai] remove Transport from GenerativeModel (#10530)
PiperOrigin-RevId: 615144883 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent e153330 commit f024111

File tree

22 files changed

+8102
-160
lines changed

22 files changed

+8102
-160
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

Lines changed: 6 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ public class GenerativeModel {
4242
private GenerationConfig generationConfig = null;
4343
private List<SafetySetting> safetySettings = null;
4444
private List<Tool> tools = null;
45-
private Transport transport;
4645

4746
public static Builder newBuilder() {
4847
return new Builder();
@@ -67,12 +66,6 @@ private GenerativeModel(Builder builder) {
6766
if (builder.tools != null) {
6867
this.tools = builder.tools;
6968
}
70-
71-
if (builder.transport != null) {
72-
this.transport = builder.transport;
73-
} else {
74-
this.transport = this.vertexAi.getTransport();
75-
}
7669
}
7770

7871
/** Builder class for {@link GenerativeModel}. */
@@ -82,7 +75,6 @@ public static class Builder {
8275
private GenerationConfig generationConfig;
8376
private List<SafetySetting> safetySettings;
8477
private List<Tool> tools;
85-
private Transport transport;
8678

8779
private Builder() {}
8880

@@ -158,15 +150,6 @@ public Builder setTools(List<Tool> tools) {
158150
}
159151
return this;
160152
}
161-
162-
/**
163-
* Sets the {@link Transport} layer for API calls in the generative model. It overrides the
164-
* transport setting in {@link com.google.cloud.vertexai.VertexAI}
165-
*/
166-
public Builder setTransport(Transport transport) {
167-
this.transport = transport;
168-
return this;
169-
}
170153
}
171154

172155
/**
@@ -180,21 +163,7 @@ public Builder setTransport(Transport transport) {
180163
* for the generative model
181164
*/
182165
public GenerativeModel(String modelName, VertexAI vertexAi) {
183-
this(modelName, null, null, vertexAi, null);
184-
}
185-
186-
/**
187-
* Constructs a GenerativeModel instance.
188-
*
189-
* @param modelName the name of the generative model. Supported format: "gemini-pro",
190-
* "models/gemini-pro", "publishers/google/models/gemini-pro"
191-
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
192-
* for the generative model
193-
* @param transport the {@link Transport} layer for API calls in the generative model. It
194-
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
195-
*/
196-
public GenerativeModel(String modelName, VertexAI vertexAi, Transport transport) {
197-
this(modelName, null, null, vertexAi, transport);
166+
this(modelName, null, null, vertexAi);
198167
}
199168

200169
/**
@@ -209,25 +178,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi, Transport transport)
209178
*/
210179
@BetaApi
211180
public GenerativeModel(String modelName, GenerationConfig generationConfig, VertexAI vertexAi) {
212-
this(modelName, generationConfig, null, vertexAi, null);
213-
}
214-
215-
/**
216-
* Constructs a GenerativeModel instance with default generation config.
217-
*
218-
* @param modelName the name of the generative model. Supported format: "gemini-pro",
219-
* "models/gemini-pro", "publishers/google/models/gemini-pro"
220-
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} instance that
221-
* will be used by default for generating response
222-
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
223-
* for the generative model
224-
* @param transport the {@link Transport} layer for API calls in the generative model. It
225-
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
226-
*/
227-
@BetaApi
228-
public GenerativeModel(
229-
String modelName, GenerationConfig generationConfig, VertexAI vertexAi, Transport transport) {
230-
this(modelName, generationConfig, null, vertexAi, transport);
181+
this(modelName, generationConfig, null, vertexAi);
231182
}
232183

233184
/**
@@ -242,28 +193,7 @@ public GenerativeModel(
242193
*/
243194
@BetaApi("safetySettings is a preview feature.")
244195
public GenerativeModel(String modelName, List<SafetySetting> safetySettings, VertexAI vertexAi) {
245-
this(modelName, null, safetySettings, vertexAi, null);
246-
}
247-
248-
/**
249-
* Constructs a GenerativeModel instance with default safety settings.
250-
*
251-
* @param modelName the name of the generative model. Supported format: "gemini-pro",
252-
* "models/gemini-pro", "publishers/google/models/gemini-pro"
253-
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} instances
254-
* that will be used by default for generating response
255-
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
256-
* for the generative model
257-
* @param transport the {@link Transport} layer for API calls in the generative model. It
258-
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
259-
*/
260-
@BetaApi("safetySettings is a preview feature.")
261-
public GenerativeModel(
262-
String modelName,
263-
List<SafetySetting> safetySettings,
264-
VertexAI vertexAi,
265-
Transport transport) {
266-
this(modelName, null, safetySettings, vertexAi, transport);
196+
this(modelName, null, safetySettings, vertexAi);
267197
}
268198

269199
/**
@@ -284,30 +214,6 @@ public GenerativeModel(
284214
GenerationConfig generationConfig,
285215
List<SafetySetting> safetySettings,
286216
VertexAI vertexAi) {
287-
this(modelName, generationConfig, safetySettings, vertexAi, null);
288-
}
289-
290-
/**
291-
* Constructs a GenerativeModel instance with default generation config and safety settings.
292-
*
293-
* @param modelName the name of the generative model. Supported format: "gemini-pro",
294-
* "models/gemini-pro", "publishers/google/models/gemini-pro"
295-
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} instance that
296-
* will be used by default for generating response
297-
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} instances
298-
* that will be used by default for generating response
299-
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
300-
* for the generative model
301-
* @param transport the {@link Transport} layer for API calls in the generative model. It
302-
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
303-
*/
304-
@BetaApi
305-
public GenerativeModel(
306-
String modelName,
307-
GenerationConfig generationConfig,
308-
List<SafetySetting> safetySettings,
309-
VertexAI vertexAi,
310-
Transport transport) {
311217
modelName = reconcileModelName(modelName);
312218
this.modelName = modelName;
313219
this.resourceName =
@@ -324,11 +230,6 @@ public GenerativeModel(
324230
}
325231
}
326232
this.vertexAi = vertexAi;
327-
if (transport != null) {
328-
this.transport = transport;
329-
} else {
330-
this.transport = vertexAi.getTransport();
331-
}
332233
}
333234

334235
/**
@@ -388,7 +289,7 @@ public CountTokensResponse countTokens(List<Content> contents) throws IOExceptio
388289
@BetaApi
389290
private CountTokensResponse countTokensFromRequest(CountTokensRequest request)
390291
throws IOException {
391-
if (this.transport == Transport.REST) {
292+
if (vertexAi.getTransport() == Transport.REST) {
392293
return vertexAi.getLlmUtilityRestClient().countTokens(request);
393294
} else {
394295
return vertexAi.getLlmUtilityClient().countTokens(request);
@@ -619,7 +520,7 @@ public GenerateContentResponse generateContent(
619520
*/
620521
private GenerateContentResponse generateContent(GenerateContentRequest request)
621522
throws IOException {
622-
if (this.transport == Transport.REST) {
523+
if (vertexAi.getTransport() == Transport.REST) {
623524
return vertexAi.getPredictionServiceRestClient().generateContentCallable().call(request);
624525
} else {
625526
return vertexAi.getPredictionServiceClient().generateContentCallable().call(request);
@@ -1031,7 +932,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
1031932
*/
1032933
private ResponseStream<GenerateContentResponse> generateContentStream(
1033934
GenerateContentRequest request) throws IOException {
1034-
if (this.transport == Transport.REST) {
935+
if (vertexAi.getTransport() == Transport.REST) {
1035936
return new ResponseStream(
1036937
new ResponseStreamIteratorWithHistory(
1037938
vertexAi
@@ -1082,24 +983,11 @@ public void setTools(List<Tool> tools) {
1082983
}
1083984
}
1084985

1085-
/**
1086-
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
1087-
* generative model.
1088-
*/
1089-
public void setTransport(Transport transport) {
1090-
this.transport = transport;
1091-
}
1092-
1093986
/** Returns the model name of this generative model. */
1094987
public String getModelName() {
1095988
return this.modelName;
1096989
}
1097990

1098-
/** Returns the {@link Transport} layer for API calls in this generative model. */
1099-
public Transport getTransport() {
1100-
return this.transport;
1101-
}
1102-
1103991
/**
1104992
* Returns the {@link com.google.cloud.vertexai.api.GenerationConfig} of this generative model.
1105993
*/

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import com.google.api.gax.rpc.ServerStreamingCallable;
2727
import com.google.api.gax.rpc.UnaryCallable;
2828
import com.google.auth.oauth2.GoogleCredentials;
29-
import com.google.cloud.vertexai.Transport;
3029
import com.google.cloud.vertexai.VertexAI;
3130
import com.google.cloud.vertexai.api.Content;
3231
import com.google.cloud.vertexai.api.CountTokensRequest;
@@ -35,15 +34,18 @@
3534
import com.google.cloud.vertexai.api.GenerateContentRequest;
3635
import com.google.cloud.vertexai.api.GenerateContentResponse;
3736
import com.google.cloud.vertexai.api.GenerationConfig;
37+
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
3838
import com.google.cloud.vertexai.api.HarmCategory;
3939
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
4040
import com.google.cloud.vertexai.api.Part;
4141
import com.google.cloud.vertexai.api.PredictionServiceClient;
42+
import com.google.cloud.vertexai.api.Retrieval;
4243
import com.google.cloud.vertexai.api.SafetySetting;
4344
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
4445
import com.google.cloud.vertexai.api.Schema;
4546
import com.google.cloud.vertexai.api.Tool;
4647
import com.google.cloud.vertexai.api.Type;
48+
import com.google.cloud.vertexai.api.VertexAISearch;
4749
import java.lang.reflect.Field;
4850
import java.util.Arrays;
4951
import java.util.Iterator;
@@ -96,14 +98,30 @@ public final class GenerativeModelTest {
9698
.build())
9799
.addRequired("location")))
98100
.build();
101+
private static final Tool GOOGLE_SEARCH_TOOL =
102+
Tool.newBuilder()
103+
.setGoogleSearchRetrieval(GoogleSearchRetrieval.newBuilder().setDisableAttribution(false))
104+
.build();
105+
private static final Tool VERTEX_AI_SEARCH_TOOL =
106+
Tool.newBuilder()
107+
.setRetrieval(
108+
Retrieval.newBuilder()
109+
.setVertexAiSearch(
110+
VertexAISearch.newBuilder()
111+
.setDatastore(
112+
String.format(
113+
"projects/%s/locations/%s/collections/%s/dataStores/%s",
114+
PROJECT, "global", "default_collection", "test_123")))
115+
.setDisableAttribution(false))
116+
.build();
99117

100118
private static final String TEXT = "What is your name?";
101119

102120
private VertexAI vertexAi;
103121
private GenerativeModel model;
104122
private List<SafetySetting> safetySettings = Arrays.asList(SAFETY_SETTING);
105123
private List<SafetySetting> defaultSafetySettings = Arrays.asList(DEFAULT_SAFETY_SETTING);
106-
private List<Tool> tools = Arrays.asList(TOOL);
124+
private List<Tool> tools = Arrays.asList(TOOL, GOOGLE_SEARCH_TOOL, VERTEX_AI_SEARCH_TOOL);
107125

108126
@Rule public final MockitoRule mocksRule = MockitoJUnit.rule();
109127

@@ -169,7 +187,6 @@ public void testInstantiateGenerativeModelwithBuilder() {
169187
assertThat(model.getGenerationConfig()).isNull();
170188
assertThat(model.getSafetySettings()).isNull();
171189
assertThat(model.getTools()).isNull();
172-
assertThat(model.getTransport()).isEqualTo(Transport.GRPC);
173190
}
174191

175192
@Test
@@ -181,13 +198,11 @@ public void testInstantiateGenerativeModelwithBuilderAllConfigs() {
181198
.setGenerationConfig(GENERATION_CONFIG)
182199
.setSafetySettings(safetySettings)
183200
.setTools(tools)
184-
.setTransport(Transport.REST)
185201
.build();
186202
assertThat(model.getModelName()).isEqualTo(MODEL_NAME);
187203
assertThat(model.getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
188204
assertThat(model.getSafetySettings()).isEqualTo(safetySettings);
189205
assertThat(model.getTools()).isEqualTo(tools);
190-
assertThat(model.getTransport()).isEqualTo(Transport.REST);
191206
}
192207

193208
@Test

0 commit comments

Comments
 (0)