Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions packages/backend/embedding_atlas/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ def determine_and_load_data(filename: str, splits: list[str] | None = None):
if filename.startswith(hf_prefix):
filename = filename.split(hf_prefix)[-1]

# Hugging Face data
if (len(filename.split("/")) <= 2) and (suffix == ""):
if Path(filename).is_dir():
df = load_huggingface_data(filename, splits)
elif (len(filename.split("/")) <= 2) and (suffix == ""):
Comment thread
bduhan marked this conversation as resolved.
Outdated
df = load_huggingface_data(filename, splits)
else:
df = load_pandas_data(filename)
Expand Down
14 changes: 12 additions & 2 deletions packages/backend/embedding_atlas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,24 @@ def load_pandas_data(url: str) -> pd.DataFrame:

def load_huggingface_data(filename: str, splits: list[str] | None) -> pd.DataFrame:
try:
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
except ImportError:
print(
"⚠️ Loading Hugging Face datasets requires the `datasets` package to be installed. Please run `pip install datasets`, then try again."
)
exit(-1)

ds: Any = load_dataset(filename)
if Path(filename).is_dir():
ds: Any = load_from_disk(filename)
else:
ds: Any = load_dataset(filename)

if not hasattr(ds, "keys"):
if splits is not None and len(splits) > 0:
raise ValueError(
"Cannot select splits for a single Hugging Face Dataset loaded from disk."
)
return ds.to_pandas()

if splits is None or len(splits) == 0:
ds_split_options = []
Expand Down
Loading