Added the ndcg metric [WIP] #2632
@kamalojasv181 thanks for the PR. I left a few comments on a few points, but I haven't yet explored the code that computes ndcg.
Can you please write a few tests using scikit-learn as a reference?
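A test against scikit-learn could look roughly like the sketch below. It assumes linear gains and tie-free predictions, and `compute_ndcg` is a hypothetical stand-in for the metric under test, not the API proposed in this PR. The reference value comes from `sklearn.metrics.ndcg_score`.

```python
import numpy as np
from sklearn.metrics import ndcg_score


def compute_ndcg(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Hypothetical stand-in for the metric under test: mean NDCG, ties ignored."""
    # Discount 1 / log2(rank + 1) for ranks 1..n_items
    discount = 1.0 / np.log2(np.arange(2, y_true.shape[1] + 2))
    # Reorder true relevances by descending predicted score
    order = np.argsort(-y_pred, axis=1)
    dcg = (np.take_along_axis(y_true, order, axis=1) * discount).sum(axis=1)
    # Ideal DCG uses the relevances sorted in descending order
    idcg = (-np.sort(-y_true, axis=1) * discount).sum(axis=1)
    return float(np.mean(dcg / idcg))


def test_ndcg_vs_sklearn():
    rng = np.random.default_rng(0)
    # Relevances in 1..4 so the ideal DCG is never zero
    y_true = rng.integers(1, 5, size=(8, 10)).astype(float)
    y_pred = rng.random((8, 10))
    expected = ndcg_score(y_true, y_pred, ignore_ties=True)
    assert abs(compute_ndcg(y_pred, y_true) - expected) < 1e-9
```

In the real test suite the stand-in would be replaced by the ignite metric, with the same random inputs fed through `update` and `compute`.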
vfdev-5
left a comment
Thanks a lot for the update @kamalojasv181 !
I left a few other comments on the implementation and the API.
Let's also start working on docs and tests.
Thanks for all the feedback. I will revert with a pull request addressing everything!
You can just continue working with this pull request, no need to revert anything.
vfdev-5
left a comment
Thanks for the updates @kamalojasv181 !
I have a few other code update suggestions.
@vfdev-5 there are a bunch of things I did in this commit:
If there is anything else, let me know before I finally add comments to the class and the documentation.
discounted_gains = torch.tensor(
    [_tie_averaged_dcg(y_p, y_t, discount_cumsum, device) for y_p, y_t in zip(y_pred, y_true)], device=device
)
So, is there no way to vectorize this, i.e. compute it without a for-loop?
I haven't checked yet. For now I have added this implementation; vectorizing it is a TODO.
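One possible way to drop the Python loop is to sort every row at once. The sketch below assumes that ties in `y_pred` can be ignored, so it does not reproduce the tie-averaged variant computed by `_tie_averaged_dcg` in the diff. The function name `ndcg_ignore_ties` and the linear gain are assumptions for illustration, not the PR's API.

```python
import torch


def ndcg_ignore_ties(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Per-sample NDCG for (n_samples, n_items) tensors, ignoring ties in y_pred."""
    n_items = y_pred.shape[1]
    relevance = y_true.to(torch.float64)

    # Discount 1 / log2(rank + 1) for ranks 1..n_items
    ranks = torch.arange(1, n_items + 1, device=y_pred.device, dtype=torch.float64)
    discount = 1.0 / torch.log2(ranks + 1.0)

    # Reorder the true relevances by descending predicted score, row by row
    order = torch.argsort(y_pred, dim=1, descending=True)
    ranked_gains = torch.gather(relevance, 1, order)
    dcg = (ranked_gains * discount).sum(dim=1)

    # Ideal DCG: relevances sorted in descending order
    ideal_gains = torch.sort(relevance, dim=1, descending=True).values
    idcg = (ideal_gains * discount).sum(dim=1)

    # Rows without any relevant item get 0 instead of NaN
    return torch.where(idcg > 0, dcg / idcg, torch.zeros_like(dcg))
```

Handling ties the way the loop version does would need an extra grouping step over equal scores, which this sketch leaves out.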
Below is a checklist for tests in a DDP configuration.
The whole process should look like the following (or you can refer to ignite/tests/ignite/metrics/test_accuracy.py, line 488 in 26f7cec: acc = Accuracy(is_multilabel=True, device=metric_device)):
import pytest
import sklearn.metrics
import torch

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.metrics import Accuracy

# rank, device, metric_device, n_classes, n_iters, batch_size and n_epochs
# are expected to be defined by the enclosing test

# data generation (a different seed on every rank)
torch.manual_seed(12 + rank)
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)

# update each batch
def update(engine, i):
    return (
        y_preds[i * batch_size : (i + 1) * batch_size, :],
        y_true[i * batch_size : (i + 1) * batch_size],
    )

# Initialize Engine
engine = Engine(update)
acc = Accuracy(device=metric_device)
acc.attach(engine, "acc")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

# all gather data from every rank
y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

res = engine.state.metrics["acc"]

# calculate reference value with scikit learn and compare
true_res = sklearn.metrics.accuracy_score(y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy())
assert pytest.approx(res) == true_res
…ne multiomial distribution
return (
    [v for v in y_preds[i * batch_size : (i + 1) * batch_size, ...]],
    [v for v in y_true[i * batch_size : (i + 1) * batch_size]],
)
@kamalojasv181 Why do you return a tuple of two lists instead of a tuple of two tensors?
I took inspiration from the code provided by @puhuk. Here each element of the list is a batch. We feed our engine one batch at a time, so using a list also works. To maintain uniformity across the code, I have kept it this way.
Oh, I see, he provided a wrong link. Yes, in accuracy we also have a test on a list of tensors and numbers, but this is atypical. Here is a typical example:
ignite/tests/ignite/metrics/test_accuracy.py
Lines 412 to 463 in 26f7cec
@sadra-barikbin can you check why
@kamalojasv181 Hi, do you need any help to finalize this PR? Please feel free to let me and @vfdev-5 know :)
Any updates on this? If it is not finished yet, I'd love to contribute, @vfdev-5.
Yes, this PR is not finished, unfortunately. @ili0820 if you can help with getting it landed it would be great!
@ili0820 yes, a few things remain here:
If you are not familiar with DDP, please check: https://pytorch-ignite.ai/tutorials/advanced/01-collective-communication/. As for testing best practices, we would now like to use the approach in ignite/tests/ignite/metrics/test_ssim.py (line 226 in f2b1183). If you have any questions, you can reach out to us on Discord in the #start-contributing channel.
Closing in favor of #3608
Related #2631
Description: This is the implementation [WIP] for the NDCG metric.
Check list: