diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs
index 601f50a4d001..83aa691482e5 100644
--- a/arrow-cast/src/cast/dictionary.rs
+++ b/arrow-cast/src/cast/dictionary.rs
@@ -315,12 +315,51 @@ pub(crate) fn cast_to_dictionary<K: ArrowDictionaryKeyType>(
         FixedSizeBinary(byte_size) => {
             pack_byte_to_fixed_size_dictionary::<K>(array, cast_options, byte_size)
         }
+        Struct(_) => pack_struct_to_dictionary::<K>(array, dict_value_type, cast_options),
         _ => Err(ArrowError::CastError(format!(
             "Unsupported output type for dictionary packing: {dict_value_type}"
         ))),
     }
 }
 
+/// Wrap a struct-valued array as a `DictionaryArray<K>` with identity
+/// keys `[0, 1, ..., len-1]`. Unlike the primitive / byte packers above, no
+/// deduplication is performed, since struct values have no general hash/equality
+/// builder in arrow-rs.
+///
+/// Each child field of the source is recursively cast to the matching field of
+/// `dict_value_type` via `cast_with_options` before keys are emitted. If any
+/// child cast fails, the whole pack fails, the same contract as the primitive
+/// packers above.
+fn pack_struct_to_dictionary<K: ArrowDictionaryKeyType>(
+    array: &dyn Array,
+    dict_value_type: &DataType,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError> {
+    let cast_values = cast_with_options(array, dict_value_type, cast_options)?;
+    let len = cast_values.len();
+
+    // Identity keys `[0, 1, ..., len-1]`, with null entries wherever the
+    // source row is null so the dictionary's logical null mask matches.
+    let mut builder = PrimitiveBuilder::<K>::with_capacity(len);
+    for i in 0..len {
+        if cast_values.is_null(i) {
+            builder.append_null();
+        } else {
+            let key = K::Native::from_usize(i).ok_or_else(|| {
+                ArrowError::CastError(format!(
+                    "Cannot fit {len} dictionary keys in {:?}",
+                    K::DATA_TYPE,
+                ))
+            })?;
+            builder.append_value(key);
+        }
+    }
+    let keys = builder.finish();
+
+    Ok(Arc::new(DictionaryArray::<K>::try_new(keys, cast_values)?))
+}
+
 // Packs the data from the primitive array of type <V> to a
 // DictionaryArray with keys of type K and values of value_type V
 pub(crate) fn pack_numeric_to_dictionary<K, V>(
diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs
index 67da85b8c1d6..6a7d31bb2a2e 100644
--- a/arrow-cast/src/cast/mod.rs
+++ b/arrow-cast/src/cast/mod.rs
@@ -6278,6 +6278,132 @@ mod tests {
         assert_ne!(keys.value(0), keys.value(1));
     }
 
+    #[test]
+    fn test_cast_struct_array_to_dict_struct() {
+        // Cast a StructArray into Dictionary<UInt32, Struct>. The dictionary
+        // value type's child fields may differ from the source's (here:
+        // Utf8 source → Utf8View child for `name`), so the per-field cast
+        // must run before identity keys are emitted. This is the "as long as
+        // the struct can be cast to the dict value" contract.
+        let names = StringArray::from(vec![Some("alpha"), None, Some("gamma")]);
+        let ids = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
+        let source = StructArray::from(vec![
+            (
+                Arc::new(Field::new("name", DataType::Utf8, true)),
+                Arc::new(names) as ArrayRef,
+            ),
+            (
+                Arc::new(Field::new("id", DataType::Int32, false)),
+                Arc::new(ids) as ArrayRef,
+            ),
+        ]);
+
+        let target_value_type = DataType::Struct(
+            vec![
+                Field::new("name", DataType::Utf8View, true),
+                Field::new("id", DataType::Int64, false),
+            ]
+            .into(),
+        );
+        let cast_type = DataType::Dictionary(
+            Box::new(DataType::UInt32),
+            Box::new(target_value_type.clone()),
+        );
+        assert!(can_cast_types(source.data_type(), &cast_type));
+
+        let cast_array = cast(&source, &cast_type).unwrap();
+        assert_eq!(cast_array.data_type(), &cast_type);
+        assert_eq!(cast_array.len(), 3);
+
+        let dict = cast_array.as_dictionary::<UInt32Type>();
+        assert_eq!(dict.values().data_type(), &target_value_type);
+        // No dedup is performed for struct values — one row, one key.
+        assert_eq!(dict.values().len(), 3);
+
+        // Source row 1 was a `Utf8`-null in the `name` field but the whole
+        // struct row was valid (StructArray::from above takes per-field
+        // nulls only). The dictionary's logical null mask therefore mirrors
+        // the source struct's row-level null mask — all rows valid here.
+        let keys = dict.keys();
+        assert_eq!(keys.values(), &[0u32, 1, 2]);
+        assert_eq!(keys.null_count(), 0);
+
+        let struct_values = dict.values().as_struct();
+        let names_out = struct_values
+            .column_by_name("name")
+            .unwrap()
+            .as_string_view();
+        assert_eq!(names_out.value(0), "alpha");
+        assert!(names_out.is_null(1));
+        assert_eq!(names_out.value(2), "gamma");
+        let ids_out = struct_values
+            .column_by_name("id")
+            .unwrap()
+            .as_primitive::<Int64Type>();
+        assert_eq!(ids_out.values(), &[1i64, 2, 3]);
+    }
+
+    #[test]
+    fn test_cast_struct_array_to_dict_struct_row_nulls() {
+        // Row-level nulls on the source struct must surface as null keys on
+        // the dictionary, since the dictionary's logical null mask is
+        // determined by the keys.
+        let names = StringArray::from(vec![Some("alpha"), Some("beta"), Some("gamma")]);
+        let ids = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
+        let source = StructArray::try_new(
+            vec![
+                Field::new("name", DataType::Utf8, true),
+                Field::new("id", DataType::Int32, false),
+            ]
+            .into(),
+            vec![Arc::new(names) as ArrayRef, Arc::new(ids) as ArrayRef],
+            Some(NullBuffer::from(vec![true, false, true])),
+        )
+        .unwrap();
+
+        let target_value_type = DataType::Struct(
+            vec![
+                Field::new("name", DataType::Utf8, true),
+                Field::new("id", DataType::Int32, false),
+            ]
+            .into(),
+        );
+        let cast_type =
+            DataType::Dictionary(Box::new(DataType::UInt32), Box::new(target_value_type));
+
+        let cast_array = cast(&source, &cast_type).unwrap();
+        let dict = cast_array.as_dictionary::<UInt32Type>();
+        assert_eq!(dict.len(), 3);
+        let keys = dict.keys();
+        assert!(!keys.is_null(0));
+        assert!(keys.is_null(1));
+        assert!(!keys.is_null(2));
+    }
+
+    #[test]
+    fn test_cast_struct_array_to_dict_struct_key_overflow() {
+        // Source has 300 rows but the dictionary key type is UInt8 (max 255).
+        // We must return a CastError instead of silently truncating.
+        let n = 300;
+        let names = StringArray::from((0..n).map(|i| Some(format!("v{i}"))).collect::<Vec<_>>());
+        let source = StructArray::from(vec![(
+            Arc::new(Field::new("name", DataType::Utf8, true)),
+            Arc::new(names) as ArrayRef,
+        )]);
+
+        let cast_type = DataType::Dictionary(
+            Box::new(DataType::UInt8),
+            Box::new(DataType::Struct(
+                vec![Field::new("name", DataType::Utf8, true)].into(),
+            )),
+        );
+        let err = cast(&source, &cast_type).unwrap_err().to_string();
+        assert!(
+            err.contains("Cannot fit") && err.contains("dictionary keys"),
+            "expected key-overflow error, got: {err}"
+        );
+    }
+
     #[test]
     fn test_cast_empty_string_array_to_dict_utf8_view() {
         let array = StringArray::from(Vec::<Option<&str>>::new());