diff --git a/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Query.cs b/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Query.cs index 8b23450a..73d6c54d 100644 --- a/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Query.cs +++ b/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Query.cs @@ -57,5 +57,8 @@ public static class Query [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_result_error_type")] public static extern DuckDBErrorType DuckDBResultErrorType(ref DuckDBResult result); + + [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_table_names")] + public static extern DuckDBValue DuckDBGetTableNames(DuckDBNativeConnection connection, string query, bool qualified); } } \ No newline at end of file diff --git a/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Value.cs b/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Value.cs index a5f106d2..7a6f5e8c 100644 --- a/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Value.cs +++ b/DuckDB.NET.Bindings/NativeMethods/NativeMethods.Value.cs @@ -170,6 +170,12 @@ public static class Value [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_is_null_value")] public static extern bool DuckDBIsNullValue(DuckDBValue value); + [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_list_size")] + public static extern ulong DuckDBGetListSize(DuckDBValue value); + + [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_get_list_child")] + public static extern DuckDBValue DuckDBGetListChild(DuckDBValue value, ulong index); + public static DuckDBValue DuckDBCreateListValue(DuckDBLogicalType logicalType, DuckDBValue[] values, int count) { var duckDBValue = DuckDBCreateListValue(logicalType, values.Select(item => item.DangerousGetHandle()).ToArray(), count); diff --git a/DuckDB.NET.Data/DuckDBDataReader.cs b/DuckDB.NET.Data/DuckDBDataReader.cs index 75094fa2..721fa07c 100644 --- a/DuckDB.NET.Data/DuckDBDataReader.cs +++ b/DuckDB.NET.Data/DuckDBDataReader.cs @@ -309,15 +309,26 @@ public override DataTable GetSchemaTable() { SchemaTableColumn.NumericPrecision, typeof(byte)}, { SchemaTableColumn.NumericScale, typeof(byte) }, { SchemaTableColumn.DataType, typeof(Type) }, - { SchemaTableColumn.AllowDBNull, typeof(bool) } + { SchemaTableColumn.AllowDBNull, typeof(bool) }, + { SchemaTableColumn.BaseSchemaName, typeof(string) }, + { SchemaTableColumn.BaseTableName, typeof(string) }, + { SchemaTableColumn.BaseColumnName, typeof(string) } } }; - var rowData = new object[7]; + // Get table names from the query + // Note: DuckDB's duckdb_get_table_names returns unique table names referenced in the query, + // not per-column mappings. For single-table queries, we can populate BaseTableName. + // For multi-table queries (joins), the mapping is not directly available from the C API. + var tableNames = GetTableNamesFromQuery(); + var singleTableName = tableNames != null && tableNames.Length == 1 ? tableNames[0] : null; + + var rowData = new object[10]; for (var i = 0; i < FieldCount; i++) { - rowData[0] = GetName(i); + var columnName = GetName(i); + rowData[0] = columnName; rowData[1] = i; rowData[2] = -1; rowData[5] = GetFieldType(i); @@ -333,12 +344,99 @@ public override DataTable GetSchemaTable() rowData[3] = rowData[4] = DBNull.Value; } + // Set table name information + // For single-table queries, populate the table name + // For multi-table queries, BaseTableName will be DBNull since we cannot determine + // which table each column comes from without additional API support + if (!string.IsNullOrEmpty(singleTableName)) + { + // The table name from duckdb_get_table_names with qualified=true should be in the format "schema.table" + // Split by the last dot to handle schema names or table names that might contain dots + var lastDotIndex = singleTableName.LastIndexOf('.'); + if (lastDotIndex > 0 && lastDotIndex < singleTableName.Length - 1) + { + rowData[7] = singleTableName.Substring(0, lastDotIndex); // BaseSchemaName + rowData[8] = singleTableName.Substring(lastDotIndex + 1); // BaseTableName + } + else + { + // No schema qualifier found, just use the table name + rowData[7] = DBNull.Value; // BaseSchemaName + rowData[8] = singleTableName; // BaseTableName + } + } + else + { + rowData[7] = DBNull.Value; + rowData[8] = DBNull.Value; + } + + rowData[9] = columnName; // BaseColumnName + table.Rows.Add(rowData); } return table; } + private string[]? GetTableNamesFromQuery() + { + try + { + var duckDBConnection = command?.Connection as DuckDBConnection; + if (duckDBConnection?.NativeConnection == null || command?.CommandText == null || string.IsNullOrEmpty(command.CommandText)) + { + return null; + } + + // Call duckdb_get_table_names with qualified=true to get schema-qualified names + using var tableNamesValue = NativeMethods.Query.DuckDBGetTableNames( + duckDBConnection.NativeConnection, + command.CommandText, + true); + + if (tableNamesValue.IsNull()) + { + return null; + } + + // Get the size of the list + var listSize = NativeMethods.Value.DuckDBGetListSize(tableNamesValue); + + // If the list is empty, return null + if (listSize == 0) + { + return null; + } + + var tableNames = new string[listSize]; + + // Extract each table name from the list + for (ulong i = 0; i < listSize; i++) + { + using var childValue = NativeMethods.Value.DuckDBGetListChild(tableNamesValue, i); + if (!childValue.IsNull()) + { + tableNames[i] = NativeMethods.Value.DuckDBGetVarchar(childValue); + } + else + { + tableNames[i] = string.Empty; + } + } + + return tableNames; + } + catch (Exception ex) when (ex is DllNotFoundException or EntryPointNotFoundException or InvalidOperationException) + { + // If we fail to get table names due to missing DLL, missing entry point, or operation errors, + // just return null. This ensures backward compatibility - if the feature isn't available or fails, + // we just don't populate the table names. + // We don't log here to avoid noise in normal operation when the feature might not be available. + return null; + } + } + public override void Close() { if (closed) return; diff --git a/DuckDB.NET.Test/DuckDBDataReaderTests.cs b/DuckDB.NET.Test/DuckDBDataReaderTests.cs index ad3cdd66..11788f0d 100644 --- a/DuckDB.NET.Test/DuckDBDataReaderTests.cs +++ b/DuckDB.NET.Test/DuckDBDataReaderTests.cs @@ -433,4 +433,67 @@ public void ReadVarint() reader.Read(); var value = (BigInteger)reader.GetValue(0); } + + [Fact] + public void GetSchemaTableReturnsBaseTableName() + { + Command.CommandText = "CREATE TABLE test_table(id INTEGER, name VARCHAR);"; + Command.ExecuteNonQuery(); + + Command.CommandText = "INSERT INTO test_table VALUES (1, 'Alice'), (2, 'Bob');"; + Command.ExecuteNonQuery(); + + Command.CommandText = "SELECT id, name FROM test_table"; + using var reader = Command.ExecuteReader(); + + var schemaTable = reader.GetSchemaTable(); + + schemaTable.Should().NotBeNull(); + schemaTable!.Rows.Count.Should().Be(2); + + // Check that BaseTableName column exists + schemaTable.Columns.Should().Contain(c => c.ColumnName == "BaseTableName"); + + // Check that both columns have the same table name + schemaTable.Rows[0]["BaseTableName"].Should().Be("test_table"); + schemaTable.Rows[1]["BaseTableName"].Should().Be("test_table"); + + // Check that BaseColumnName is populated + schemaTable.Rows[0]["BaseColumnName"].Should().Be("id"); + schemaTable.Rows[1]["BaseColumnName"].Should().Be("name"); + } + + [Fact] + public void GetSchemaTableForJoinReturnsDbNull() + { + // Note: DuckDB's duckdb_get_table_names C API returns a list of unique table names + // referenced in the query, but does not provide per-column table name mapping. + // Therefore, for queries with multiple tables (joins), BaseTableName will be DBNull. + + Command.CommandText = "CREATE TABLE users(id INTEGER, name VARCHAR);"; + Command.ExecuteNonQuery(); + + Command.CommandText = "CREATE TABLE orders(order_id INTEGER, user_id INTEGER);"; + Command.ExecuteNonQuery(); + + Command.CommandText = "INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob');"; + Command.ExecuteNonQuery(); + + Command.CommandText = "INSERT INTO orders VALUES (100, 1), (200, 2);"; + Command.ExecuteNonQuery(); + + Command.CommandText = "SELECT u.id, u.name, o.order_id FROM users u JOIN orders o ON u.id = o.user_id"; + using var reader = Command.ExecuteReader(); + + var schemaTable = reader.GetSchemaTable(); + + schemaTable.Should().NotBeNull(); + schemaTable!.Rows.Count.Should().Be(3); + + // For join queries, BaseTableName should be DBNull since we can't determine + // which table each column comes from without additional API support + schemaTable.Rows[0]["BaseTableName"].Should().Be(DBNull.Value); + schemaTable.Rows[1]["BaseTableName"].Should().Be(DBNull.Value); + schemaTable.Rows[2]["BaseTableName"].Should().Be(DBNull.Value); + } }