3636import com .google .common .base .Supplier ;
3737import com .google .common .base .Suppliers ;
3838import com .google .common .collect .ImmutableList ;
39+ import com .google .common .collect .ImmutableMap ;
3940import com .google .errorprone .annotations .CanIgnoreReturnValue ;
4041import java .io .IOException ;
42+ import java .util .HashMap ;
4143import java .util .List ;
44+ import java .util .Map ;
4245import java .util .Optional ;
4346import java .util .logging .Level ;
4447import java .util .logging .Logger ;
@@ -63,6 +66,7 @@ public class VertexAI implements AutoCloseable {
6366 private final String location ;
6467 private final String apiEndpoint ;
6568 private final Transport transport ;
69+ private final HeaderProvider headerProvider ;
6670 private final CredentialsProvider credentialsProvider ;
6771
6872 private final transient Supplier <PredictionServiceClient > predictionClientSupplier ;
@@ -85,6 +89,7 @@ public VertexAI(String projectId, String location) {
8589 location ,
8690 Transport .GRPC ,
8791 ImmutableList .of (),
92+ /* customHeaders= */ ImmutableMap .of (),
8893 /* credentials= */ Optional .empty (),
8994 /* apiEndpoint= */ Optional .empty (),
9095 /* predictionClientSupplierOpt= */ Optional .empty (),
@@ -108,6 +113,7 @@ public VertexAI() {
108113 null ,
109114 Transport .GRPC ,
110115 ImmutableList .of (),
116+ /* customHeaders= */ ImmutableMap .of (),
111117 /* credentials= */ Optional .empty (),
112118 /* apiEndpoint= */ Optional .empty (),
113119 /* predictionClientSupplierOpt= */ Optional .empty (),
@@ -119,6 +125,7 @@ private VertexAI(
119125 String location ,
120126 Transport transport ,
121127 List <String > scopes ,
128+ Map <String , String > customHeaders ,
122129 Optional <Credentials > credentials ,
123130 Optional <String > apiEndpoint ,
124131 Optional <Supplier <PredictionServiceClient >> predictionClientSupplierOpt ,
@@ -131,6 +138,15 @@ private VertexAI(
131138 this .location = Strings .isNullOrEmpty (location ) ? inferLocation () : location ;
132139 this .transport = transport ;
133140
141+ String sdkHeader =
142+ String .format (
143+ "%s/%s" ,
144+ Constants .USER_AGENT_HEADER ,
145+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class ));
146+ Map <String , String > headers = new HashMap <>(customHeaders );
147+ headers .compute ("user-agent" , (k , v ) -> v == null ? sdkHeader : sdkHeader + " " + v );
148+ this .headerProvider = FixedHeaderProvider .create (headers );
149+
134150 if (credentials .isPresent ()) {
135151 this .credentialsProvider = FixedCredentialsProvider .create (credentials .get ());
136152 } else {
@@ -160,6 +176,7 @@ public static class Builder {
160176 private String location ;
161177 private Transport transport = Transport .GRPC ;
162178 private ImmutableList <String > scopes = ImmutableList .of ();
179+ private ImmutableMap <String , String > customHeaders = ImmutableMap .of ();
163180 private Optional <Credentials > credentials = Optional .empty ();
164181 private Optional <String > apiEndpoint = Optional .empty ();
165182
@@ -174,6 +191,7 @@ public VertexAI build() {
174191 location ,
175192 transport ,
176193 scopes ,
194+ customHeaders ,
177195 credentials ,
178196 apiEndpoint ,
179197 Optional .ofNullable (predictionClientSupplier ),
@@ -240,6 +258,14 @@ public Builder setScopes(List<String> scopes) {
240258 this .scopes = ImmutableList .copyOf (scopes );
241259 return this ;
242260 }
261+
262+ @ CanIgnoreReturnValue
263+ public Builder setCustomHeaders (Map <String , String > customHeaders ) {
264+ checkNotNull (customHeaders , "customHeaders can't be null" );
265+
266+ this .customHeaders = ImmutableMap .copyOf (customHeaders );
267+ return this ;
268+ }
243269 }
244270
245271 /**
@@ -278,6 +304,15 @@ public String getApiEndpoint() {
278304 return apiEndpoint ;
279305 }
280306
307+ /**
308+ * Returns the headers to use when making API calls.
309+ *
310+ * @return a map of headers to use when making API calls.
311+ */
312+ public Map <String , String > getHeaders () {
313+ return headerProvider .getHeaders ();
314+ }
315+
281316 /**
282317 * Returns the default credentials to use when making API calls.
283318 *
0 commit comments