Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ public Optional<Type> visit(LogicalTypeAnnotation.JsonLogicalTypeAnnotation json
public Optional<Type> visit(LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonType) {
return Optional.of(Types.BinaryType.get());
}

@Override
public Optional<Type> visit(LogicalTypeAnnotation.UUIDLogicalTypeAnnotation uuidType) {
return Optional.of(Types.UUIDType.get());
}
}

private void addAlias(String name, int fieldId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1559,13 +1559,12 @@ public <D> CloseableIterable<D> build() {
} catch (IOException e) {
throw new RuntimeIOException(e);
}
Schema fileSchema = ParquetSchemaUtil.convert(type);
builder
.useStatsFilter()
.useDictionaryFilter()
.useRecordFilter(filterRecords)
.useBloomFilter()
.withFilter(ParquetFilters.convert(fileSchema, filter, caseSensitive));
.withFilter(ParquetFilters.convert(type, filter, caseSensitive));
} else {
// turn off filtering
builder
Expand Down
145 changes: 139 additions & 6 deletions parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
*/
package org.apache.iceberg.parquet;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.UUID;
import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.BoundPredicate;
import org.apache.iceberg.expressions.BoundReference;
Expand All @@ -29,19 +33,31 @@
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.Literal;
import org.apache.iceberg.expressions.UnboundPredicate;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.DecimalUtil;
import org.apache.iceberg.util.UUIDUtil;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.filter2.compat.FilterCompat;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.filter2.predicate.Operators;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;

class ParquetFilters {

private ParquetFilters() {}

static FilterCompat.Filter convert(Schema schema, Expression expr, boolean caseSensitive) {
static FilterCompat.Filter convert(
MessageType parquetSchema, Expression expr, boolean caseSensitive) {
Schema schema = ParquetSchemaUtil.convert(parquetSchema);
FilterPredicate pred =
ExpressionVisitors.visit(expr, new ConvertFilterToParquet(schema, caseSensitive));
ExpressionVisitors.visit(
expr,
new ConvertFilterToParquet(
schema, primitiveTypesById(parquetSchema, schema), caseSensitive));
// TODO: handle AlwaysFalse.INSTANCE
if (pred != null && pred != AlwaysTrue.INSTANCE) {
// FilterCompat will apply LogicalInverseRewriter
Expand All @@ -51,12 +67,30 @@ static FilterCompat.Filter convert(Schema schema, Expression expr, boolean caseS
}
}

private static Map<Integer, PrimitiveType> primitiveTypesById(
MessageType parquetSchema, Schema schema) {
Map<Integer, PrimitiveType> primitiveTypesById = Maps.newHashMap();

for (ColumnDescriptor desc : parquetSchema.getColumns()) {
PrimitiveType primitiveType = parquetSchema.getType(desc.getPath()).asPrimitiveType();
Integer fieldId = schema.aliasToId(String.join(".", desc.getPath()));
if (fieldId != null) {
primitiveTypesById.put(fieldId, primitiveType);
}
}

return primitiveTypesById;
}

private static class ConvertFilterToParquet extends ExpressionVisitor<FilterPredicate> {
private final Schema schema;
private final Map<Integer, PrimitiveType> primitiveTypesById;
private final boolean caseSensitive;

private ConvertFilterToParquet(Schema schema, boolean caseSensitive) {
private ConvertFilterToParquet(
Schema schema, Map<Integer, PrimitiveType> primitiveTypesById, boolean caseSensitive) {
this.schema = schema;
this.primitiveTypesById = primitiveTypesById;
this.caseSensitive = caseSensitive;
}

Expand Down Expand Up @@ -149,11 +183,18 @@ public <T> FilterPredicate predicate(BoundPredicate<T> pred) {
case DOUBLE:
return pred(op, FilterApi.doubleColumn(path), getParquetPrimitive(lit));
case STRING:
case UUID:
case FIXED:
case BINARY:
case DECIMAL:
return pred(op, FilterApi.binaryColumn(path), getParquetPrimitive(lit));
case UUID:
return pred(op, FilterApi.binaryColumn(path), getParquetUUID(lit));
case DECIMAL:
return decimalPred(
op,
path,
primitiveTypesById.get(ref.fieldId()),
(Types.DecimalType) ref.type().asPrimitiveType(),
lit);
}

throw new UnsupportedOperationException("Cannot convert to Parquet filter: " + pred);
Expand All @@ -173,6 +214,42 @@ public <T> FilterPredicate predicate(UnboundPredicate<T> pred) {
}
}

private static FilterPredicate decimalPred(
Operation op,
String path,
PrimitiveType primitiveType,
Types.DecimalType decimalType,
Literal<?> lit) {
if (primitiveType == null) {
return AlwaysTrue.INSTANCE;
}

BigDecimal decimal = decimalValue(decimalType, lit);
if (lit != null && decimal == null) {
return AlwaysTrue.INSTANCE;
}

try {
switch (primitiveType.getPrimitiveTypeName()) {
case INT32:
return pred(op, FilterApi.intColumn(path), getDecimalAsInt(decimal));
case INT64:
return pred(op, FilterApi.longColumn(path), getDecimalAsLong(decimal));
case FIXED_LEN_BYTE_ARRAY:
return pred(
op,
FilterApi.binaryColumn(path),
getDecimalAsFixed(decimalType, primitiveType.getTypeLength(), decimal));
case BINARY:
return pred(op, FilterApi.binaryColumn(path), getDecimalAsBinary(decimal));
default:
return AlwaysTrue.INSTANCE;
}
} catch (ArithmeticException e) {
return AlwaysTrue.INSTANCE;
}
}

@SuppressWarnings("checkstyle:MethodTypeParameterName")
private static <C extends Comparable<C>, COL extends Operators.Column<C> & Operators.SupportsLtGt>
FilterPredicate pred(Operation op, COL col, C value) {
Expand Down Expand Up @@ -214,13 +291,69 @@ FilterPredicate pred(Operation op, COL col, C value) {
}
}

private static Integer getDecimalAsInt(BigDecimal decimal) {
if (decimal == null) {
return null;
}

return decimal.unscaledValue().intValueExact();
}

private static Long getDecimalAsLong(BigDecimal decimal) {
if (decimal == null) {
return null;
}

return decimal.unscaledValue().longValueExact();
}

private static Binary getDecimalAsFixed(Types.DecimalType type, int length, BigDecimal decimal) {
if (decimal == null) {
return null;
}

byte[] bytes =
DecimalUtil.toReusedFixLengthBytes(
type.precision(), type.scale(), decimal, new byte[length]);
return Binary.fromConstantByteArray(bytes);
}

private static Binary getDecimalAsBinary(BigDecimal decimal) {
if (decimal == null) {
return null;
}

return Binary.fromConstantByteArray(decimal.unscaledValue().toByteArray());
}

private static BigDecimal decimalValue(Types.DecimalType type, Literal<?> lit) {
if (lit == null) {
return null;
}

BigDecimal decimal = (BigDecimal) lit.value();
try {
BigDecimal scaled = decimal.setScale(type.scale(), RoundingMode.UNNECESSARY);
return scaled.precision() <= type.precision() ? scaled : null;
} catch (ArithmeticException e) {
return null;
}
}

private static Binary getParquetUUID(Literal<?> lit) {
if (lit == null) {
return null;
}

return Binary.fromConstantByteArray(UUIDUtil.convert((UUID) lit.value()));
}

@SuppressWarnings("unchecked")
private static <C extends Comparable<C>> C getParquetPrimitive(Literal<?> lit) {
if (lit == null) {
return null;
}

// TODO: this needs to convert to handle BigDecimal and UUID
Object value = lit.value();
if (value instanceof Number) {
return (C) lit.value();
Expand Down
Loading