Data Transform hack to apply Sklearn Feature Scaling on your custom Dataset class #5022
snknitin started this conversation in Show and tell
I think this is cool, thanks for sharing. We could think about integrating this as part of a …
Hey PyG Team,
I've been working on heterogeneous graphs for a while, and I have multiple edge types with edge attributes in my HeteroData(). There is no data transform that does feature scaling, and standardization and pre-processing are currently expected to be handled by the user. Sometimes you might want your input embeddings or feature values to stay unchanged from the raw files and use transforms instead to standardize or normalize them, depending on the use case. Hard-coding the scaling in the Dataset's process() method can cause a couple of issues, so I've built a custom transform that handles this. You can refer to this discussion from last month to get a better idea. I finally got some time to refine and post my approach.
Quick Context
Solution
In the process() method of your custom Dataset, add a scaler to each edge type the same way you create its other attributes:
data['node1', 'to', 'node2'].edge_scaler = scaler
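The `scaler` assigned above has to be fitted before process() stores it. Below is a minimal sketch of one way to do that, assuming every raw file is a CSV with a header row and numeric columns for the edge features. It streams the files one at a time and calls `partial_fit` on each, so all files contribute to the statistics without being concatenated in memory. The helper names `read_edge_features` and `fit_edge_scaler` are hypothetical, not part of PyG or the original post. Any `sklearn.preprocessing` scaler that implements `partial_fit` (for example `StandardScaler`, `MinMaxScaler` or `MaxAbsScaler`) can be passed in; `RobustScaler` does not implement it and would need a single `fit` on the stacked data instead.

```python
import csv
from pathlib import Path
from typing import Iterable, Sequence, Union

import numpy as np

PathLike = Union[str, Path]


def read_edge_features(path: PathLike, columns: Sequence[str]) -> np.ndarray:
    """Read the given columns of one CSV file into a float array [num_rows, len(columns)]."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)  # assumes a header row naming the columns
        rows = [[float(row[c]) for c in columns] for row in reader]
    return np.asarray(rows, dtype=np.float64).reshape(-1, len(columns))


def fit_edge_scaler(paths: Iterable[PathLike], columns: Sequence[str], scaler):
    """Hypothetical helper: fit `scaler` incrementally on `columns` from every raw file.

    `scaler` is any sklearn-style object with a partial_fit method, e.g. StandardScaler().
    """
    for path in paths:
        feats = read_edge_features(path, columns)
        if len(feats) > 0:  # skip files with no data rows
            scaler.partial_fit(feats)
    return scaler
```

Inside process(), this could be called once before the loop that builds each graph, e.g. `scaler = fit_edge_scaler(self.raw_paths, ["weight", "distance"], StandardScaler())`, where the column names are placeholders. Fitting once means every graph shares the same statistics.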
This way, your edge store will have a scaler item, which isn't memory-heavy. You can compute this scaler within the Dataset class with a method like this. Here I'm reading all my graph data from .csv files. I choose the specific columns of the files that will be that edge type's features and fit the scaler on all the raw files. You can choose any scaling approach from sklearn.preprocessing. Now you can collate and process the graph.
Transform
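A minimal sketch of such a transform, assuming each edge store carries a fitted scaler under `edge_scaler` (as set in process()) and its features under `edge_attr`. The class name `ScaleEdgeAttr` is hypothetical and this is not the author's implementation. It is written as a plain callable, which is all `T.Compose` and a Dataset's `transform=` argument need; subclassing `torch_geometric.transforms.BaseTransform` would work too. Because sklearn scalers operate on arrays, tensors are moved to NumPy, scaled, and converted back with their original dtype and device.

```python
import numpy as np


class ScaleEdgeAttr:
    """Hypothetical transform: apply each edge store's fitted sklearn scaler to its features.

    Works on any object exposing `edge_stores` (PyG's Data and HeteroData both do).
    """

    def __init__(self, attr: str = "edge_attr", scaler_key: str = "edge_scaler"):
        self.attr = attr
        self.scaler_key = scaler_key

    def __call__(self, data):
        for store in data.edge_stores:
            scaler = getattr(store, self.scaler_key, None)
            feats = getattr(store, self.attr, None)
            if scaler is None or feats is None:
                continue  # this edge type has nothing to scale
            if hasattr(feats, "detach"):  # torch.Tensor: go through NumPy and back
                import torch

                scaled = scaler.transform(feats.detach().cpu().numpy())
                setattr(store, self.attr,
                        torch.as_tensor(scaled, dtype=feats.dtype, device=feats.device))
            else:  # plain array-like
                setattr(store, self.attr, np.asarray(scaler.transform(np.asarray(feats))))
        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(attr={self.attr!r})"
```

For example, it could go first in `T.Compose([ScaleEdgeAttr(), T.ToUndirected(), T.NormalizeFeatures()])`, so the reverse edge types created by `ToUndirected()` receive already-scaled copies of `edge_attr`. Looping over `data.node_stores` with a node-level scaler would be the analogous change for node features.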
Now, when you want to load your data and apply the transform, you can do it this way. Note that you would ideally want to run the scaling before other transforms such as NormalizeFeatures. This can be modified to work on node features as well.
Bonus Transform
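On HeteroData, `T.ToUndirected()` adds a reverse edge type `(dst, 'rev_<relation>', src)` for edge types between different node types and copies edge-level tensors onto it, which is how `edge_label` ends up on the reverse edges. A minimal sketch of a transform in the spirit of RevDelete(), not the author's code: the class name `DropRevEdgeLabel` and its defaults are hypothetical, and it assumes reverse relations use the default `rev_` prefix.

```python
class DropRevEdgeLabel:
    """Hypothetical transform: delete a label attribute from reverse edge types.

    Reverse edge types are recognised by a relation name starting with `prefix`,
    matching the 'rev_<relation>' naming that ToUndirected() uses on HeteroData.
    """

    def __init__(self, key: str = "edge_label", prefix: str = "rev_"):
        self.key = key
        self.prefix = prefix

    def __call__(self, data):
        for edge_type in data.edge_types:
            _src, rel, _dst = edge_type
            store = data[edge_type]
            if rel.startswith(self.prefix) and self.key in store:
                del store[self.key]  # forward edge types keep their labels
        return data
```

It would typically run directly after `T.ToUndirected()` in a `T.Compose` pipeline.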
If you're curious about RevDelete(), that is another transform I made to delete the edge_label that gets created on the reverse edges of HeteroData objects after the ToUndirected() transform. It's a small piece of code that saves some manual effort.
Acknowledgement
Hope this is helpful to someone. I would love to know what you think and whether there is a better way to do this. Let me know if I can contribute to the package in some way. I'd be stoked if these could be refined and adopted into the built-in methods!