-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelper_augmentations.py
30 lines (26 loc) · 1020 Bytes
/
helper_augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import random
from random import randint
from random import shuffle
import torchvision
class SwapReferenceTest(object):
    """Randomly exchange the roles of the reference and test images.

    With probability ``p`` the pair is returned as ``(test, reference)``;
    otherwise it is returned unchanged as ``(reference, test)``.

    Args:
        p: probability of swapping, in ``[0, 1]``. Defaults to 0.5, the
            original fifty-fifty behaviour.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p=0.5):
        if not 0.0 <= p <= 1.0:
            raise ValueError('p must be in [0, 1], got %r' % (p,))
        self.p = p

    def __call__(self, sample):
        """Return the ``(reference, test)`` pair, possibly swapped.

        Args:
            sample: either a mapping with ``'reference'`` and ``'test'``
                keys, or a ``(reference, test)`` tuple/list (e.g. the tuple
                this transform itself returns), so transforms can be chained.

        Returns:
            A 2-tuple. Keys of a mapping sample other than ``'reference'``
            and ``'test'`` are not carried over.
        """
        # Draw first (as the original did). Keeping the order when the draw
        # exceeds p reproduces the original ``prob > 0.5`` check for p=0.5,
        # so seeded runs behave exactly as before.
        keep_order = random.random() > self.p
        if isinstance(sample, (tuple, list)):
            reference, test = sample
        else:
            reference, test = sample['reference'], sample['test']
        if keep_order:
            return reference, test
        return test, reference
class JitterGamma(object):
    """Randomly apply one shared gamma correction to the reference and test images.

    With probability ``p`` a single gamma value is drawn uniformly from
    ``[gamma_range[0], gamma_range[1])`` and applied to BOTH images via
    ``torchvision.transforms.functional.adjust_gamma``. Otherwise the images
    are returned unchanged. Using the same gamma for both images keeps their
    relative brightness consistent.

    Args:
        p: probability of applying the jitter, in ``[0, 1]``. Defaults to 0.5.
        gamma_range: ``(low, high)`` bounds for the gamma draw. Defaults to
            ``(0.1, 1.1)``, the same range as the original
            ``random.random() + 0.1``. ``adjust_gamma`` requires a
            non-negative gamma, so ``low`` must be >= 0.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]`` or ``gamma_range`` is not
            a valid non-negative, non-decreasing pair.
    """

    def __init__(self, p=0.5, gamma_range=(0.1, 1.1)):
        if not 0.0 <= p <= 1.0:
            raise ValueError('p must be in [0, 1], got %r' % (p,))
        low, high = gamma_range
        if not 0.0 <= low <= high:
            raise ValueError(
                'gamma_range must satisfy 0 <= low <= high, got %r' % (gamma_range,))
        self.p = p
        self.gamma_range = (low, high)

    def __call__(self, sample):
        """Return ``(reference, test)``, both gamma-adjusted or both untouched.

        Args:
            sample: either a mapping with ``'reference'`` and ``'test'``
                keys, or a ``(reference, test)`` tuple/list, so this can be
                chained after transforms that return a tuple. The images must
                be accepted by ``adjust_gamma`` (per torchvision docs, a PIL
                Image or a Tensor).

        Returns:
            A 2-tuple ``(reference, test)``.
        """
        # Coin flip first, as in the original. ``> 1 - p`` reproduces the
        # original ``prob > 0.5`` check for p=0.5, so seeded runs are unchanged.
        apply_jitter = random.random() > 1.0 - self.p
        if isinstance(sample, (tuple, list)):
            trf_reference, trf_test = sample
        else:
            trf_reference, trf_test = sample['reference'], sample['test']
        if apply_jitter:
            low, high = self.gamma_range
            gamma = low + (high - low) * random.random()
            # Same gamma for both images so they stay photometrically comparable.
            trf_reference = torchvision.transforms.functional.adjust_gamma(trf_reference, gamma)
            trf_test = torchvision.transforms.functional.adjust_gamma(trf_test, gamma)
        return trf_reference, trf_test