Skip to content

Fix batch classification TypeError when using data_path (#611)#616

Merged
zhmiao merged 1 commit into
microsoft:mainfrom
jQuinRivero:fix/batch-classification-imagefolder-error
May 24, 2026
Merged

Fix batch classification TypeError when using data_path (#611)#616
zhmiao merged 1 commit into
microsoft:mainfrom
jQuinRivero:fix/batch-classification-imagefolder-error

Conversation

@jQuinRivero
Copy link
Copy Markdown
Contributor

Fixes #611

Problem

Calling batch_image_classification(data_path=...) on any TIMM-based or ResNet-based classifier raises a TypeError because pw_data.ImageFolder does not accept a path_head keyword argument:

# Both timm_base and resnet_base had this:
dataset = pw_data.ImageFolder(
    data_path,
    transform=self.transform,
    path_head='.'        # ← ImageFolder doesn't accept this
)

Additionally, ImageFolder is the abstract base class whose __getitem__ returns None, so even if the kwarg issue were fixed, the dataloader would fail to unpack (img, img_path).

Fix

  • Replace pw_data.ImageFolder with pw_data.ClassificationImageFolder — the correct subclass that implements __getitem__ returning (img, img_path).
  • Remove the invalid path_head='.' argument (ClassificationImageFolder builds full paths via os.walk, so path_head is unnecessary).
  • Add ClassificationImageFolder to __all__ in datasets.py for export consistency.

Files changed

File Change
PytorchWildlife/data/datasets.py Added ClassificationImageFolder to __all__
PytorchWildlife/models/classification/timm_base/base_classifier.py ImageFolderClassificationImageFolder, removed path_head
PytorchWildlife/models/classification/resnet_base/base_classifier.py Same fix

Safety

  • The det_results code path (using DetectionCrops) is untouchedpath_head is valid there.
  • No other file in the repo references pw_data.ImageFolder.
  • All detectors already use DetectionImageFolder correctly.
  • The broken code path could never have worked, so this is purely additive — no existing working behavior changes.

…\n\nReplace pw_data.ImageFolder with pw_data.ClassificationImageFolder in\nboth timm_base and resnet_base classifiers, and remove the invalid\npath_head keyword argument that ImageFolder does not accept.\n\nImageFolder.__getitem__ is abstract (returns None), so the correct\nsubclass for classification is ClassificationImageFolder, which returns\nthe (img, img_path) tuple the dataloader loop expects.\n\nAlso export ClassificationImageFolder from datasets __all__."
Copy link
Copy Markdown
Collaborator

@zhmiao zhmiao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@zhmiao
Copy link
Copy Markdown
Collaborator

zhmiao commented May 24, 2026

Thank you so much @jQuinRivero ! I have merged it.

@zhmiao zhmiao merged commit 1cd46cf into microsoft:main May 24, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Batch classification fails when using data_path in PytorchWildlife/models/classification/timm_base/base_classifier.py

2 participants