-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathusu05a.py
144 lines (110 loc) · 3.99 KB
/
usu05a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# USU OS Python Assignment 5a
#
# NDVI - Normalized Difference Vegetation Index
import gdal
from gdalconst import *
import numpy as np
import numpy.ma as ma
import time
# Make GDAL report errors as Python exceptions (RuntimeError) instead of
# returning None.
gdal.UseExceptions()

SRC_RASTERFILE = 'data/usu04/aster.img'
DST_RASTERFILE = 'data/usu04/aster-ndvi.img'
DST_DATA_TYPE = np.float32
NO_DATA_VALUE = -99

# Register the raster driver and open the source data.
gdal.AllRegister()
try:
    srcDS = gdal.Open(SRC_RASTERFILE, GA_ReadOnly)
except RuntimeError:
    # Exceptions are enabled above, so a failed Open raises here.
    srcDS = None
if srcDS is None:
    # SystemExit with a string argument prints it to stderr and exits with
    # status 1.
    raise SystemExit("Can't open source raster file " + SRC_RASTERFILE)

cols = srcDS.RasterXSize
rows = srcDS.RasterYSize
src_bands = srcDS.RasterCount
print('\nRaster file: ' + SRC_RASTERFILE)
print('\nRows x Columns x Bands: {0:d} x {1:d} x {2:d}'.format(rows, cols, src_bands))

# The NDVI computation reads band 2 (red) and band 3 (NIR), so at least three
# bands are needed. This also ensures the band loop below defines xblock/yblock.
if src_bands < 3:
    raise SystemExit('NDVI needs at least 3 bands; ' + SRC_RASTERFILE +
                     ' has {0:d}'.format(src_bands))

projection = srcDS.GetProjection()
geotransform = srcDS.GetGeoTransform()
x0, dx, rx, y0, ry, dy = geotransform
print('\nTop-Left corner (x,y): {0:12.4f}, {1:12.4f}'.format(x0, y0))
print('Pixel resolution (x,y): {0:12.4f}, {1:12.4f}'.format(dx, dy))
print('Axis rotation (x,y): {0:12.4f}, {1:12.4f}'.format(rx, ry))
print()

# Describe every band. The block size of the last band (xblock, yblock) is
# reused later as the processing window size.
for n in range(1, src_bands + 1):
    srcBand = srcDS.GetRasterBand(n)
    xblock, yblock = srcBand.GetBlockSize()
    datatype = gdal.GetDataTypeName(srcBand.DataType)
    color_interp = gdal.GetColorInterpretationName(srcBand.GetColorInterpretation())
    print('Band {0}: Block={1}x{2} Type={3} ColorInterp={4}'.format(
        n, xblock, yblock, datatype, color_interp))
# Open the output raster file: one Float32 band, same size and same driver
# (file format) as the source.
driver = srcDS.GetDriver()
try:
    dstDS = driver.Create(DST_RASTERFILE, cols, rows, 1, GDT_Float32)
except RuntimeError:
    # With gdal.UseExceptions() active, a failed Create raises rather than
    # returning None.
    dstDS = None
if dstDS is None:
    # SystemExit with a string argument prints it to stderr and exits with
    # status 1.
    raise SystemExit("Can't open destination raster file " + DST_RASTERFILE)
dstBand = dstDS.GetRasterBand(1)
startTime = time.time()

# Band 2 is read as red and band 3 as near-infrared (NIR). Look the bands up
# once instead of once per block.
redBand = srcDS.GetRasterBand(2)
nirBand = srcDS.GetRasterBand(3)

# Running totals for the summary report.
valid_count = 0
ndvi_sum = 0.0
no_valid = some_valid = all_valid = 0

# Process the raster one xblock x yblock window at a time. Windows on the
# right and bottom edges are clipped to the raster size.
for yoff in range(0, rows, yblock):
    ysize = min(rows - yoff, yblock)
    for xoff in range(0, cols, xblock):
        xsize = min(cols - xoff, xblock)
        block_size = xsize * ysize

        red = redBand.ReadAsArray(xoff, yoff, xsize, ysize).astype(DST_DATA_TYPE)
        nir = nirBand.ReadAsArray(xoff, yoff, xsize, ysize).astype(DST_DATA_TYPE)

        # NDVI denominator. Pixels where it is 0 would divide by zero. They are
        # treated as invalid: written as NO_DATA_VALUE and left out of the mean.
        b = nir + red
        mask = np.equal(b, 0)
        masked = np.count_nonzero(mask)
        unmasked = block_size - masked

        # Calculate ndvi = (nir - red) / (nir + red).
        # Block sums are taken in float64 and converted to a Python float.
        # Adding a float32 scalar would, under NumPy's NEP 50 promotion rules,
        # turn the running total itself into float32.
        if unmasked == 0:
            # No valid data in this block.
            ndvi = np.full((ysize, xsize), NO_DATA_VALUE, dtype=DST_DATA_TYPE)
            no_valid += 1
        elif masked > 0:
            # Some valid data: the masked array skips the zero denominators.
            ndvi = ma.array(nir - red, mask=mask) / b
            valid_count += unmasked
            ndvi_sum += float(ndvi.sum(dtype=np.float64))
            # Pass the fill value explicitly instead of relying on it being
            # carried through the masked division.
            ndvi = ndvi.filled(NO_DATA_VALUE)
            some_valid += 1
        else:
            # All data in the block is valid.
            ndvi = (nir - red) / b
            valid_count += block_size
            ndvi_sum += float(ndvi.sum(dtype=np.float64))
            all_valid += 1
        dstBand.WriteArray(ndvi, xoff, yoff)

# Release the source band handles.
srcBand = redBand = nirBand = None
# Finish up with the output raster: nodata value, statistics, georeferencing
# copied from the source, and overviews.
dstBand.SetNoDataValue(NO_DATA_VALUE)
dstBand.FlushCache()
dstBand.GetStatistics(0, 1)  # approx_ok=0, force=1
dstDS.SetGeoTransform(geotransform)
dstDS.SetProjection(projection)
# Config option read by GDAL's HFA (.img) driver: requests overviews in an
# external .rrd file.
gdal.SetConfigOption('HFA_USE_RRD', 'YES')
dstDS.BuildOverviews(overviewlist=[2, 4, 8, 16, 32, 64, 128])

pixel_count = rows * cols
print('\nValid blocks: NONE={0} SOME={1} ALL={2}'.format(no_valid, some_valid, all_valid))
print('\nValid Pixels/Total Pixels: {0:,d}/{1:,d} ({2:.0f}%)'.format(
    valid_count, pixel_count, 100.0 * valid_count / pixel_count
))
# A raster with no valid pixels would otherwise raise ZeroDivisionError here.
if valid_count > 0:
    print('\nNDVI mean: {0:.3f}'.format(ndvi_sum / valid_count))
else:
    print('\nNDVI mean: n/a (no valid pixels)')

endTime = time.time()
print('\ntime: {0:.3f} sec'.format(endTime - startTime))

# Done. Dereference the band and datasets explicitly so GDAL flushes and
# closes the output file now rather than during interpreter shutdown.
dstBand = None
dstDS = None
srcDS = None
raise SystemExit(0)