-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathutils.py
executable file
·145 lines (118 loc) · 4.45 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import sys
from torch.optim.lr_scheduler import _LRScheduler
# Function for Computing the Precision
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for each requested value of k.

    Args:
        output: Score tensor; the top-k search runs along dimension 1.
        target: Label tensor whose first dimension is the batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``, each
        holding the percentage (0-100) of samples whose target appears
        among the k highest-scoring predictions.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the `largest_k` best scores per sample, best first, then
    # transposed so that row i holds every sample's (i+1)-th best guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()

    # Boolean mask marking where a ranked guess matches the sample's target.
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    # For each k, count the hits within the first k rows and scale the
    # count to a percentage of the batch.
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
        for k in topk
    ]
# Class for Computing and Storing the Average Value
class AverageMeter(object):
    """Computes and stores the average and current value.

    Args:
        name: Label used when the meter is rendered as text.
        fmt: Format-spec suffix (e.g. ``':f'`` or ``':6.2f'``) applied to
            both the current value and the average in ``__str__``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear every accumulated statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` observations.

        Args:
            val: Most recent measurement.
            n: Number of observations that ``val`` represents.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so that an update with n == 0 on an empty
        # meter leaves the average at zero instead of raising.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as ``"<name> <val> (<avg>)"`` using the stored ``fmt``."""
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(name=self.name, val=self.val, avg=self.avg)
# Class for Storing the Progress of the Model
class ProgressMeter(object):
    """Formats the state of a group of meters as progress-report text.

    Args:
        num_batches: Total number of batches; fixes the width of the
            ``[current/total]`` counter.
        meters: Iterable of meter objects. ``display`` uses ``str(meter)``;
            ``compose_json`` and ``display_avg`` read ``meter.name`` and
            ``meter.avg``.
        logger: Optional logger object. It is stored on the instance but
            not called by this class.
            NOTE(review): its interface is not visible here, so the
            formatted lines are returned to the caller rather than emitted.
        prefix: Text placed at the start of every formatted line.
    """

    def __init__(self, num_batches, meters, logger=None, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.logger = logger

    def display(self, batch):
        """Build the progress line for batch index ``batch``.

        Returns:
            The prefix plus batch counter, followed by ``str(meter)`` for
            every meter, joined by tab characters.
        """
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        # Return the line so the formatting work is not silently discarded.
        return '\t'.join(entries)

    def compose_json(self):
        """Return a dict mapping each meter's name to its running average."""
        return {meter.name: meter.avg for meter in self.meters}

    def display_avg(self):
        """Build a summary line containing every meter's running average.

        Returns:
            The prefix followed by ``"<name>:<avg>"`` for every meter
            (average formatted as ``6.3f``), joined by tab characters.
        """
        entries = [self.prefix]
        entries += [f"{meter.name}:{meter.avg:6.3f}" for meter in self.meters]
        return '\t'.join(entries)

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``'[{:Nd}/total]'`` template sized to ``num_batches``."""
        # Pad the running index to the same width as the total.
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
# Class for the Warmup Learning Rate Scheduler
class WarmUpLR(_LRScheduler):
    """Learning-rate scheduler for the warm-up phase of training.

    Args:
        optimizer: Optimizer (e.g. SGD) whose learning rates are scheduled.
        total_iters: Total number of iterations in the warm-up phase.
        last_epoch: Passed through unchanged to ``_LRScheduler``.
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Assigned ahead of the base initialiser, as in the original;
        # presumably the base class may already call get_lr() while it
        # is being constructed.
        self.total_iters = total_iters
        super(WarmUpLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one warm-up learning rate per base learning rate.

        With m = ``self.last_epoch``, each rate is
        ``base_lr * m / total_iters``.
        """
        # The tiny epsilon keeps the division defined when total_iters is 0.
        warmup_span = self.total_iters + 1e-8
        rates = []
        for base_lr in self.base_lrs:
            rates.append(base_lr * self.last_epoch / warmup_span)
        return rates
# Class for the Logger
class Logger(object):
    """Writer that sends every message to both stdout and a log file.

    The stream that ``sys.stdout`` refers to at construction time is kept
    as the terminal target, and ``filename`` is opened in append mode.

    Args:
        filename: Path of the log file to append to.
    """

    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both targets so buffered text actually reaches them."""
        self.terminal.flush()
        # Flushing a closed file raises, so skip it once close() has run.
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; calling this more than once is safe."""
        if not self.log.closed:
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Release the file handle; returning False lets exceptions propagate.
        self.close()
        return False
# Function for Walking through All the Files
def walkFile(file):
    """Count the files found beneath a directory tree.

    Args:
        file: Root directory handed to ``os.walk``.

    Returns:
        The number of files found, or 0 as soon as a zero-length file is
        encountered (and also when the tree contains no files at all).
    """
    total = 0
    for folder, _subdirs, names in os.walk(file):
        for name in names:
            total += 1
            # A single empty file makes the whole tree count as 0.
            if not os.path.getsize(os.path.join(folder, name)):
                return 0
    return total
# Function for Checking the File Number
def check_file_number(dir_to_check, domainbed_dataset):
    """Check whether a dataset directory holds the expected number of files.

    Args:
        dir_to_check: Directory whose files are counted with ``walkFile``.
        domainbed_dataset: Dataset name; must be one of the keys of the
            expected-count table below, otherwise ``KeyError`` is raised.

    Returns:
        True when the count reported by ``walkFile`` equals the expected
        count for the dataset, False otherwise.
    """
    expected_counts = {
        'PACS': 41,
        'VLCS': 44,
        'OfficeHome': 64,
        'TerraIncognita': 97,
        'DomainNet': 2295,
    }
    # Count first, then look up the expectation (same order as before).
    found = walkFile(dir_to_check)
    expected = expected_counts[domainbed_dataset]
    return found == expected