Best practice: where to count number of samples per class #15199
Unanswered
mfoglio asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Hello, I am looking for best practices here, as I know there could be multiple ways to solve this problem. I would just like to understand whether there is a Lightning approach that should be preferred.
In my `LightningModule` I initialize a `CrossEntropyLoss` with a specific `weight` to handle imbalanced classes: `torch.nn.CrossEntropyLoss(weight=my_weights)`. The weight for each class is defined as `1 / number_of_samples_in_the_class`. In order to do this, I need to supply my `LightningModule` instance with the number of samples per class. However, you would usually load the data (and therefore count the number of samples per class) in the `setup` function of the `LightningDataModule` instance. So here's the problem: usually, when you initialize the `LightningModule`, you haven't loaded the data yet. Example:
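Roughly something like this (a minimal, self-contained sketch; `MyModel`, `MyDataModule`, and the dummy data are made up just to illustrate the ordering problem):

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

NUM_CLASSES = 3

class MyModel(pl.LightningModule):
    """Needs the per-class sample counts at construction time."""
    def __init__(self, samples_per_class):
        super().__init__()
        # weight for each class = 1 / number_of_samples_in_the_class
        weight = 1.0 / torch.as_tensor(samples_per_class, dtype=torch.float)
        self.criterion = torch.nn.CrossEntropyLoss(weight=weight)
        self.layer = torch.nn.Linear(8, NUM_CLASSES)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.criterion(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

class MyDataModule(pl.LightningDataModule):
    """The counts only become known here, inside setup()."""
    def setup(self, stage=None):
        x = torch.randn(100, 8)
        y = torch.randint(0, NUM_CLASSES, (100,))  # imbalanced in practice
        self.train_dataset = TensorDataset(x, y)
        self.samples_per_class = torch.bincount(y, minlength=NUM_CLASSES)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32)

# The usual flow constructs the model before setup() has run, so the
# counts are not yet available when MyModel.__init__ needs them:
datamodule = MyDataModule()
# model = MyModel(samples_per_class=???)  # not known yet!
# trainer = pl.Trainer(); trainer.fit(model, datamodule=datamodule)
```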
As possible solutions, I could manually call `my_data_module.setup()`, or I could compute the number of samples inside the `__init__` function of `MyDataModule`, but neither way seems to follow the PyTorch Lightning philosophy. What would be the cleanest way to solve this? Thank you!
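For reference, the first workaround would look roughly like this, continuing the hypothetical sketch above:

```python
# Workaround mentioned in the question: call setup() by hand before
# constructing the module, so the counts exist when __init__ runs.
datamodule = MyDataModule()
datamodule.setup(stage="fit")  # normally invoked by the Trainer
model = MyModel(samples_per_class=datamodule.samples_per_class)

# The Trainer may invoke setup() again during fit(); in this sketch
# that is harmless because it just rebuilds the same dataset.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)
```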