#!/usr/bin/python
# -*- coding: utf-8 -*-
import osr
import gdal
import numpy
from PIL import Image
from StringIO import StringIO
def reproject_vrt(src_epsg, src_gt, src_size, dst_epsg):
    """Compute the destination grid for warping a raster between EPSG codes.

    A throw-away VRT dataset carrying only the source geometry (size,
    geotransform, projection) is handed to ``gdal.AutoCreateWarpedVRT``. The
    warped VRT that comes back is used only for its grid description.

    Args:
        src_epsg: EPSG code of the source projection.
        src_gt: GDAL geotransform of the source raster.
        src_size: source raster size; element 0 is the width and element 1
            the height, in pixels.
        dst_epsg: EPSG code of the destination projection.

    Returns:
        Tuple ``(geotransform, width, height)`` of the warped destination grid.
    """
    def wkt_for(epsg):
        # Build an OSR spatial reference from an EPSG code and serialise it.
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(epsg)
        return srs.ExportToWkt()

    src_wkt = wkt_for(src_epsg)

    template = gdal.GetDriverByName('VRT').Create('', src_size[0], src_size[1])
    template.SetGeoTransform(src_gt)
    template.SetProjection(src_wkt)

    warped = gdal.AutoCreateWarpedVRT(template, src_wkt, wkt_for(dst_epsg))
    return warped.GetGeoTransform(), warped.RasterXSize, warped.RasterYSize
def _gdal_type_for(dtype):
    """Return the GDAL pixel type matching numpy *dtype* (GDT_Byte if unmapped)."""
    mapping = {
        'uint8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    # Unmapped dtypes fall back to GDT_Byte, the type Driver.Create uses when
    # no eType is given (the previous behaviour for every band).
    return mapping.get(numpy.dtype(dtype).name, gdal.GDT_Byte)


def reproject_band(band, src_epsg, src_gt, dst_epsg, dst_gt, dst_width, dst_height,
                   resample_alg=gdal.GRA_Bilinear):
    """Warp a single 2-D band onto a destination grid in another projection.

    In-memory (MEM) source and destination datasets are created with a GDAL
    pixel type matching ``band.dtype``. Previously both datasets were always
    GDT_Byte, which silently truncated 16-bit, 32-bit and float bands. The
    band is then reprojected with ``gdal.ReprojectImage``.

    Args:
        band: 2-D numpy array of shape ``(height, width)``.
        src_epsg: EPSG code of the source projection.
        src_gt: GDAL geotransform of *band*.
        dst_epsg: EPSG code of the destination projection.
        dst_gt: GDAL geotransform of the destination grid.
        dst_width: destination width in pixels.
        dst_height: destination height in pixels.
        resample_alg: GDAL resampling algorithm (default ``gdal.GRA_Bilinear``,
            as before). Pass ``gdal.GRA_NearestNeighbour`` for categorical or
            palette-index data, where interpolation is meaningless.

    Returns:
        The warped band as returned by ``ReadAsArray`` on the single-band
        destination dataset. Destination pixels not covered by the source are
        presumably left at the MEM dataset's initial value (0) — confirm if
        callers rely on it.
    """
    src_height, src_width = band.shape
    data_type = _gdal_type_for(band.dtype)
    driver = gdal.GetDriverByName('MEM')

    src_proj = osr.SpatialReference()
    src_proj.ImportFromEPSG(src_epsg)
    src_ds = driver.Create('', src_width, src_height, 1, data_type)
    src_ds.SetGeoTransform(src_gt)
    src_ds.SetProjection(src_proj.ExportToWkt())
    src_ds.GetRasterBand(1).WriteArray(band)

    dst_proj = osr.SpatialReference()
    dst_proj.ImportFromEPSG(dst_epsg)
    dst_ds = driver.Create('', dst_width, dst_height, 1, data_type)
    dst_ds.SetGeoTransform(dst_gt)
    dst_ds.SetProjection(dst_proj.ExportToWkt())

    gdal.ReprojectImage(src_ds, dst_ds, None, None, resample_alg)
    return dst_ds.ReadAsArray()
def reproject_raster(raster, src_epsg, src_gt, dst_epsg):
    """Reproject an encoded image between two EPSG projections.

    The image is decoded with PIL and its destination grid is computed with
    ``reproject_vrt``. Each band is warped with ``reproject_band``, and the
    result is re-encoded in the same format the input was decoded from.

    Args:
        raster: encoded image bytes readable by ``PIL.Image.open``.
        src_epsg: EPSG code of the source projection.
        src_gt: GDAL geotransform of the source image.
        dst_epsg: EPSG code of the destination projection.

    Returns:
        Encoded image bytes, saved with the source image's PIL format and mode.

    Raises:
        ValueError: if the decoded pixel array is neither 2-D nor 3-D.
    """
    # io.BytesIO (unlike the Python-2-only StringIO module) handles binary
    # image data on both Python 2 and 3.
    import io

    src_img = Image.open(io.BytesIO(raster))
    dst_gt, dst_width, dst_height = reproject_vrt(src_epsg, src_gt, src_img.size, dst_epsg)
    src_array = numpy.array(src_img)

    if src_array.ndim == 3:
        # Multi-band image (e.g. RGB/RGBA): warp each band independently.
        dst_array = numpy.zeros((dst_height, dst_width, src_array.shape[2]), src_array.dtype)
        for i in range(src_array.shape[2]):
            dst_array[:, :, i] = reproject_band(
                src_array[:, :, i], src_epsg, src_gt, dst_epsg, dst_gt, dst_width, dst_height)
    elif src_array.ndim == 2:
        # NOTE(review): for palette ('P') images this bilinearly interpolates
        # palette indices, which can produce wrong colours at edges.
        dst_array = reproject_band(
            src_array, src_epsg, src_gt, dst_epsg, dst_gt, dst_width, dst_height)
    else:
        raise ValueError('Unexpected array geometry!')

    # GDAL may hand back a different dtype than the source. Cast so the byte
    # count matches what Image.frombytes expects for src_img.mode, and pass a
    # C-ordered bytes copy rather than the raw buffer.
    dst_array = dst_array.astype(src_array.dtype, copy=False)
    dst_img = Image.frombytes(src_img.mode, (dst_width, dst_height), dst_array.tobytes())

    # frombytes does not carry the palette over; copy it so palette images
    # keep their colours.
    if src_img.mode in ('P', 'PA'):
        palette = src_img.getpalette()
        if palette is not None:
            dst_img.putpalette(palette)

    out = io.BytesIO()
    dst_img.save(out, format=src_img.format)
    return out.getvalue()