Skip to content

Commit a8bd704

Browse files
author
Constantin Muraru
committed
1 parent 5cf9248 commit a8bd704

5 files changed

Lines changed: 97 additions & 60 deletions

File tree

parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor
388388
if (parquetType.asGroupType().containsField("list")) {
389389
parquetSchema = parquetType.asGroupType().getType("list");
390390
if (parquetSchema.asGroupType().containsField("element")) {
391-
parquetSchema.asGroupType().getType("element");
391+
parquetSchema = parquetSchema.asGroupType().getType("element");
392392
}
393393
} else {
394394
throw new ParquetDecodingException("Expected list but got: " + parquetType);
@@ -403,10 +403,6 @@ public Converter getConverter(int fieldIndex) {
403403
throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper");
404404
}
405405

406-
if (listOfMessage) {
407-
return converter;
408-
}
409-
410406
return new GroupConverter() {
411407
@Override
412408
public Converter getConverter(int fieldIndex) {

parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*
1+
/*
22
* Licensed to the Apache Software Foundation (ASF) under one
33
* or more contributor license agreements. See the NOTICE file
44
* distributed with this work for additional information
@@ -19,6 +19,7 @@
1919
package org.apache.parquet.proto;
2020

2121
import com.google.protobuf.Descriptors;
22+
import com.google.protobuf.Descriptors.FieldDescriptor;
2223
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType;
2324
import com.google.protobuf.Message;
2425
import com.twitter.elephantbird.util.Protobufs;
@@ -59,8 +60,8 @@ public MessageType convert(Class<? extends Message> protobufClass) {
5960
}
6061

6162
/** Iterates over the list of fields. */
62-
private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<Descriptors.FieldDescriptor> fieldDescriptors) {
63-
for (Descriptors.FieldDescriptor fieldDescriptor : fieldDescriptors) {
63+
private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<FieldDescriptor> fieldDescriptors) {
64+
for (FieldDescriptor fieldDescriptor : fieldDescriptors) {
6465
groupBuilder =
6566
addField(fieldDescriptor, groupBuilder)
6667
.id(fieldDescriptor.getNumber())
@@ -69,7 +70,7 @@ private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<Des
6970
return groupBuilder;
7071
}
7172

72-
private Type.Repetition getRepetition(Descriptors.FieldDescriptor descriptor) {
73+
private Type.Repetition getRepetition(FieldDescriptor descriptor) {
7374
if (descriptor.isRequired()) {
7475
return Type.Repetition.REQUIRED;
7576
} else if (descriptor.isRepeated()) {
@@ -79,7 +80,7 @@ private Type.Repetition getRepetition(Descriptors.FieldDescriptor descriptor) {
7980
}
8081
}
8182

82-
private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addField(Descriptors.FieldDescriptor descriptor, final GroupBuilder<T> builder) {
83+
private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addField(FieldDescriptor descriptor, final GroupBuilder<T> builder) {
8384
if (descriptor.getJavaType() == JavaType.MESSAGE) {
8485
return addMessageField(descriptor, builder);
8586
}
@@ -92,7 +93,7 @@ private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addF
9293
return builder.primitive(parquetType.primitiveType, getRepetition(descriptor)).as(parquetType.originalType);
9394
}
9495

95-
private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addRepeatedPrimitive(Descriptors.FieldDescriptor descriptor,
96+
private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addRepeatedPrimitive(FieldDescriptor descriptor,
9697
PrimitiveTypeName primitiveType,
9798
OriginalType originalType,
9899
final GroupBuilder<T> builder) {
@@ -104,18 +105,19 @@ private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addR
104105
.named("list");
105106
}
106107

107-
private <T> GroupBuilder<GroupBuilder<T>> addRepeatedMessage(Descriptors.FieldDescriptor descriptor, GroupBuilder<T> builder) {
108-
GroupBuilder<GroupBuilder<GroupBuilder<T>>> result =
108+
private <T> GroupBuilder<GroupBuilder<T>> addRepeatedMessage(FieldDescriptor descriptor, GroupBuilder<T> builder) {
109+
GroupBuilder<GroupBuilder<GroupBuilder<GroupBuilder<T>>>> result =
109110
builder
110111
.group(Type.Repetition.REQUIRED).as(OriginalType.LIST)
111-
.group(Type.Repetition.REPEATED);
112+
.group(Type.Repetition.REPEATED)
113+
.group(Type.Repetition.OPTIONAL);
112114

113115
convertFields(result, descriptor.getMessageType().getFields());
114116

115-
return result.named("list");
117+
return result.named("element").named("list");
116118
}
117119

118-
private <T> GroupBuilder<GroupBuilder<T>> addMessageField(Descriptors.FieldDescriptor descriptor, final GroupBuilder<T> builder) {
120+
private <T> GroupBuilder<GroupBuilder<T>> addMessageField(FieldDescriptor descriptor, final GroupBuilder<T> builder) {
119121
if (descriptor.isMapField()) {
120122
return addMapField(descriptor, builder);
121123
} else if (descriptor.isRepeated()) {
@@ -128,24 +130,24 @@ private <T> GroupBuilder<GroupBuilder<T>> addMessageField(Descriptors.FieldDescr
128130
return group;
129131
}
130132

131-
private <T> GroupBuilder<GroupBuilder<T>> addMapField(Descriptors.FieldDescriptor descriptor, final GroupBuilder<T> builder) {
132-
List<Descriptors.FieldDescriptor> fields = descriptor.getMessageType().getFields();
133+
private <T> GroupBuilder<GroupBuilder<T>> addMapField(FieldDescriptor descriptor, final GroupBuilder<T> builder) {
134+
List<FieldDescriptor> fields = descriptor.getMessageType().getFields();
133135
if (fields.size() != 2) {
134136
throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields);
135137
}
136138

137139
ParquetType mapKeyParquetType = getParquetType(fields.get(0));
138140

139141
GroupBuilder<GroupBuilder<GroupBuilder<T>>> group = builder
140-
.group(Type.Repetition.REQUIRED).as(OriginalType.MAP)
142+
.group(Type.Repetition.OPTIONAL).as(OriginalType.MAP) // only optional maps are allowed in Proto3
141143
.group(Type.Repetition.REPEATED) // key_value wrapper
142144
.primitive(mapKeyParquetType.primitiveType, Type.Repetition.REQUIRED).as(mapKeyParquetType.originalType).named("key");
143145

144146
return addField(fields.get(1), group).named("value")
145147
.named("key_value");
146148
}
147149

148-
private ParquetType getParquetType(Descriptors.FieldDescriptor fieldDescriptor) {
150+
private ParquetType getParquetType(FieldDescriptor fieldDescriptor) {
149151

150152
JavaType javaType = fieldDescriptor.getJavaType();
151153
switch (javaType) {

parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,17 @@
1818
*/
1919
package org.apache.parquet.proto;
2020

21-
import com.google.protobuf.ByteString;
22-
import com.google.protobuf.DescriptorProtos;
23-
import com.google.protobuf.Descriptors;
24-
import com.google.protobuf.MapEntry;
25-
import com.google.protobuf.Message;
26-
import com.google.protobuf.MessageOrBuilder;
27-
import com.google.protobuf.TextFormat;
21+
import com.google.protobuf.*;
22+
import com.google.protobuf.Descriptors.Descriptor;
23+
import com.google.protobuf.Descriptors.FieldDescriptor;
2824
import com.twitter.elephantbird.util.Protobufs;
2925
import org.apache.hadoop.conf.Configuration;
3026
import org.apache.parquet.hadoop.BadConfigurationException;
3127
import org.apache.parquet.hadoop.api.WriteSupport;
3228
import org.apache.parquet.io.InvalidRecordException;
3329
import org.apache.parquet.io.api.Binary;
3430
import org.apache.parquet.io.api.RecordConsumer;
35-
import org.apache.parquet.schema.GroupType;
36-
import org.apache.parquet.schema.IncompatibleSchemaModificationException;
37-
import org.apache.parquet.schema.MessageType;
38-
import org.apache.parquet.schema.OriginalType;
31+
import org.apache.parquet.schema.*;
3932
import org.apache.parquet.schema.Type;
4033
import org.slf4j.Logger;
4134
import org.slf4j.LoggerFactory;
@@ -113,7 +106,7 @@ public WriteContext init(Configuration configuration) {
113106
}
114107

115108
MessageType rootSchema = new ProtoSchemaConverter().convert(protoMessage);
116-
Descriptors.Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage);
109+
Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage);
117110
validatedMapping(messageDescriptor, rootSchema);
118111

119112
this.messageWriter = new MessageWriter(messageDescriptor, rootSchema);
@@ -156,11 +149,11 @@ class MessageWriter extends FieldWriter {
156149
final FieldWriter[] fieldWriters;
157150

158151
@SuppressWarnings("unchecked")
159-
MessageWriter(Descriptors.Descriptor descriptor, GroupType schema) {
160-
List<Descriptors.FieldDescriptor> fields = descriptor.getFields();
152+
MessageWriter(Descriptor descriptor, GroupType schema) {
153+
List<FieldDescriptor> fields = descriptor.getFields();
161154
fieldWriters = (FieldWriter[]) Array.newInstance(FieldWriter.class, fields.size());
162155

163-
for (Descriptors.FieldDescriptor fieldDescriptor: fields) {
156+
for (FieldDescriptor fieldDescriptor: fields) {
164157
String name = fieldDescriptor.getName();
165158
Type type = schema.getType(name);
166159
FieldWriter writer = createWriter(fieldDescriptor, type);
@@ -176,7 +169,7 @@ class MessageWriter extends FieldWriter {
176169
}
177170
}
178171

179-
private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) {
172+
private FieldWriter createWriter(FieldDescriptor fieldDescriptor, Type type) {
180173

181174
switch (fieldDescriptor.getJavaType()) {
182175
case STRING: return new StringWriter() ;
@@ -193,7 +186,7 @@ private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Ty
193186
return unknownType(fieldDescriptor);//should not be executed, always throws exception.
194187
}
195188

196-
private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) {
189+
private FieldWriter createMessageWriter(FieldDescriptor fieldDescriptor, Type type) {
197190
if (fieldDescriptor.isMapField()) {
198191
return createMapWriter(fieldDescriptor, type);
199192
}
@@ -203,7 +196,7 @@ private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescrip
203196

204197
private GroupType getGroupType(Type type) {
205198
if (type.getOriginalType() == OriginalType.LIST) {
206-
return type.asGroupType().getType("list").asGroupType();
199+
return type.asGroupType().getType("list").asGroupType().getType("element").asGroupType();
207200
}
208201

209202
if (type.getOriginalType() == OriginalType.MAP) {
@@ -213,20 +206,20 @@ private GroupType getGroupType(Type type) {
213206
return type.asGroupType();
214207
}
215208

216-
private MapWriter createMapWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) {
217-
List<Descriptors.FieldDescriptor> fields = fieldDescriptor.getMessageType().getFields();
209+
private MapWriter createMapWriter(FieldDescriptor fieldDescriptor, Type type) {
210+
List<FieldDescriptor> fields = fieldDescriptor.getMessageType().getFields();
218211
if (fields.size() != 2) {
219212
throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields);
220213
}
221214

222215
// KeyFieldWriter
223-
Descriptors.FieldDescriptor keyProtoField = fields.get(0);
216+
FieldDescriptor keyProtoField = fields.get(0);
224217
FieldWriter keyWriter = createWriter(keyProtoField, type);
225218
keyWriter.setFieldName(keyProtoField.getName());
226219
keyWriter.setIndex(0);
227220

228221
// ValueFieldWriter
229-
Descriptors.FieldDescriptor valueProtoField = fields.get(1);
222+
FieldDescriptor valueProtoField = fields.get(1);
230223
FieldWriter valueWriter = createWriter(valueProtoField, type);
231224
valueWriter.setFieldName(valueProtoField.getName());
232225
valueWriter.setIndex(1);
@@ -257,10 +250,10 @@ final void writeField(Object value) {
257250

258251
private void writeAllFields(MessageOrBuilder pb) {
259252
// Returns the fields that have been set, with their values. The map is ordered by field number.
260-
Map<Descriptors.FieldDescriptor, Object> changedPbFields = pb.getAllFields();
253+
Map<FieldDescriptor, Object> changedPbFields = pb.getAllFields();
261254

262-
for (Map.Entry<Descriptors.FieldDescriptor, Object> entry : changedPbFields.entrySet()) {
263-
Descriptors.FieldDescriptor fieldDescriptor = entry.getKey();
255+
for (Map.Entry<FieldDescriptor, Object> entry : changedPbFields.entrySet()) {
256+
FieldDescriptor fieldDescriptor = entry.getKey();
264257

265258
if(fieldDescriptor.isExtension()) {
266259
// Field index of an extension field might overlap with a base field.
@@ -295,13 +288,21 @@ final void writeField(Object value) {
295288
recordConsumer.startField("list", 0); // This is the wrapper group for the array field
296289
for (Object listEntry: list) {
297290
recordConsumer.startGroup();
298-
if (isPrimitive(listEntry)) {
299-
recordConsumer.startField("element", 0);
291+
292+
recordConsumer.startField("element", 0); // This is the mandatory inner field
293+
294+
if (!isPrimitive(listEntry)) {
295+
recordConsumer.startGroup();
300296
}
297+
301298
fieldWriter.writeRawValue(listEntry);
302-
if (isPrimitive(listEntry)) {
303-
recordConsumer.endField("element", 0);
299+
300+
if (!isPrimitive(listEntry)) {
301+
recordConsumer.endGroup();
304302
}
303+
304+
recordConsumer.endField("element", 0);
305+
305306
recordConsumer.endGroup();
306307
}
307308
recordConsumer.endField("list", 0);
@@ -316,10 +317,10 @@ private boolean isPrimitive(Object listEntry) {
316317
}
317318

318319
/** Validates the mapping between protobuf fields and parquet fields. */
319-
private void validatedMapping(Descriptors.Descriptor descriptor, GroupType parquetSchema) {
320-
List<Descriptors.FieldDescriptor> allFields = descriptor.getFields();
320+
private void validatedMapping(Descriptor descriptor, GroupType parquetSchema) {
321+
List<FieldDescriptor> allFields = descriptor.getFields();
321322

322-
for (Descriptors.FieldDescriptor fieldDescriptor: allFields) {
323+
for (FieldDescriptor fieldDescriptor: allFields) {
323324
String fieldName = fieldDescriptor.getName();
324325
int fieldIndex = fieldDescriptor.getIndex();
325326
int parquetIndex = parquetSchema.getFieldIndex(fieldName);
@@ -370,10 +371,16 @@ final void writeRawValue(Object value) {
370371
recordConsumer.startGroup();
371372

372373
recordConsumer.startField("key_value", 0); // This is the wrapper group for the map field
373-
for(MapEntry<?, ?> entry : (Collection<MapEntry<?, ?>>) value) {
374+
for (Message msg : (Collection<Message>) value) {
374375
recordConsumer.startGroup();
375-
keyWriter.writeField(entry.getKey());
376-
valueWriter.writeField(entry.getValue());
376+
377+
final Descriptor descriptorForType = msg.getDescriptorForType();
378+
final FieldDescriptor keyDesc = descriptorForType.findFieldByName("key");
379+
final FieldDescriptor valueDesc = descriptorForType.findFieldByName("value");
380+
381+
keyWriter.writeField(msg.getField(keyDesc));
382+
valueWriter.writeField(msg.getField(valueDesc));
383+
377384
recordConsumer.endGroup();
378385
}
379386

@@ -421,15 +428,15 @@ final void writeRawValue(Object value) {
421428
}
422429
}
423430

424-
private FieldWriter unknownType(Descriptors.FieldDescriptor fieldDescriptor) {
431+
private FieldWriter unknownType(FieldDescriptor fieldDescriptor) {
425432
String exceptionMsg = "Unknown type with descriptor \"" + fieldDescriptor
426433
+ "\" and type \"" + fieldDescriptor.getJavaType() + "\".";
427434
throw new InvalidRecordException(exceptionMsg);
428435
}
429436

430437
/** Returns the message descriptor, converted to a DescriptorProto and printed in protobuf text format (not JSON). */
431438
private String serializeDescriptor(Class<? extends Message> protoClass) {
432-
Descriptors.Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass);
439+
Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass);
433440
DescriptorProtos.DescriptorProto asProto = descriptor.toProto();
434441
return TextFormat.printToString(asProto);
435442
}

parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public void testProto3ConvertAllDatatypes() throws Exception {
103103
" optional binary optionalEnum (ENUM) = 18;" +
104104
" optional int32 someInt32 = 19;" +
105105
" optional binary someString (UTF8) = 20;" +
106-
" required group optionalMap (MAP) = 21 {\n" +
106+
" optional group optionalMap (MAP) = 21 {\n" +
107107
" repeated group key_value {\n" +
108108
" required int64 key;\n" +
109109
" optional group value {\n" +
@@ -135,7 +135,9 @@ public void testConvertRepetition() throws Exception {
135135
" }\n" +
136136
" required group repeatedMessage (LIST) = 9 {\n" +
137137
" repeated group list {\n" +
138-
" optional int32 someId = 3;\n" +
138+
" optional group element {\n" +
139+
" optional int32 someId = 3;\n" +
140+
" }\n" +
139141
" }\n" +
140142
" }" +
141143
"}";
@@ -158,7 +160,9 @@ public void testProto3ConvertRepetition() throws Exception {
158160
" }\n" +
159161
" required group repeatedMessage (LIST) = 9 {\n" +
160162
" repeated group list {\n" +
161-
" optional int32 someId = 3;\n" +
163+
" optional group element {\n" +
164+
" optional int32 someId = 3;\n" +
165+
" }\n" +
162166
" }\n" +
163167
" }\n" +
164168
"}";

0 commit comments

Comments
 (0)