仅个人记录:shp裁剪tif;shp裁剪shp;矢量转栅格;多个shp裁剪shp;栅格边界矢量化。汇总:输入shp和影像,输出影像对应的标签(栅格边界矢量化,shp裁剪shp,shp转tif)_运用shp文件裁剪tif-程序员宅基地

技术标签: 原型模式  

shp裁剪tif

# -*- coding: utf-8 -*-
import os
import numpy as np
from osgeo import gdal, gdalnumeric, ogr, osr, gdal_array
gdal.UseExceptions()

def world2Pixel(geoMatrix, x, y):
  """
  Uses a gdal geomatrix (gdal.GetGeoTransform()) to calculate
  the pixel location of a geospatial coordinate
  """
  ulX = geoMatrix[0]
  ulY = geoMatrix[3]
  xDist = geoMatrix[1]
  yDist = geoMatrix[5]
  rtnX = geoMatrix[2]
  rtnY = geoMatrix[4]
  pixel = int((x - ulX) / xDist)
  line = int((ulY - y) / xDist)
  return (pixel, line)

#
#  EDIT: this is basically an overloaded
#  version of the gdal_array.OpenArray passing in xoff, yoff explicitly
#  so we can pass these params off to CopyDatasetInfo
#
def OpenArray( array, prototype_ds = None, xoff=0, yoff=0 ):
    # ds = gdal.Open( gdalnumeric.GetArrayFilename(array))
    ds = gdal_array.OpenArray(array)

    if ds is not None and prototype_ds is not None:
        if type(prototype_ds).__name__ == 'str':
            prototype_ds = gdal.Open( prototype_ds )
        if prototype_ds is not None:
            gdalnumeric.CopyDatasetInfo( prototype_ds, ds, xoff=xoff, yoff=yoff )
    return ds


def write_img(filename,im_proj,im_geotrans,im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset

def shpClipRaster(shapefile_path, raster_path, save_path):
    # Load the source data as a gdalnumeric array
    # srcArray = gdalnumeric.LoadFile(raster_path)

    # Also load as a gdal image to get geotransform
    # (world file) info
    srcImage = gdal.Open(raster_path)
    geoTrans = srcImage.GetGeoTransform()
    geoProj = srcImage.GetProjection()

    # Create an OGR layer from a boundary shapefile
    shapef = ogr.Open(shapefile_path)
    lyr = shapef.GetLayer( os.path.split( os.path.splitext( shapefile_path )[0] )[1] )
    poly = lyr.GetNextFeature()

    # Convert the layer extent to image pixel coordinates
    minX, maxX, minY, maxY = lyr.GetExtent()
    ulX, ulY = world2Pixel(geoTrans, minX, maxY)
    lrX, lrY = world2Pixel(geoTrans, maxX, minY)

    # Calculate the pixel size of the new image
    pxWidth = int(lrX - ulX)
    pxHeight = int(lrY - ulY)

    # clip = srcArray[:, ulY:lrY, ulX:lrX]
    clip = srcImage.ReadAsArray(ulX,ulY,pxWidth,pxHeight)   #***只读要的那块***

    #
    # EDIT: create pixel offset to pass to new image Projection info
    #
    xoffset =  ulX
    yoffset =  ulY
    print ("Xoffset, Yoffset = ( %f, %f )" % ( xoffset, yoffset ))

    # Create a new geomatrix for the image
    geoTrans = list(geoTrans)
    geoTrans[0] = minX
    geoTrans[3] = maxY

    write_img(save_path, geoProj, geoTrans, clip)
    gdal.ErrorReset()

if __name__ == "__main__":
    shp = "dataset/E22_Bound.shp"
    img = "dataset/CGdomYRJ-114(CK0-17)_E_22.tif"
    out = "dataset/E22.tif"

    shpClipRaster(shp,img,out)
    print(img)

shp裁剪shp

import os
from osgeo import gdal, ogr

def ShapeClip(
		baseFilePath,
		maskFilePath,
		saveFolderPath):
	"""
	矢量裁剪
	:param baseFilePath: 要裁剪的矢量文件
	:param maskFilePath: 掩膜矢量文件
	:param saveFolderPath: 裁剪后的矢量文件保存目录
	:return:
	"""
	ogr.RegisterAll()
	gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
	# 载入要裁剪的矢量文件

	baseData = ogr.Open(baseFilePath)
	print(os.path.split( os.path.splitext( baseFilePath )[0] )[1])
	baseLayer = baseData.GetLayer( os.path.split( os.path.splitext( baseFilePath )[0] )[1] )

	spatial = baseLayer.GetSpatialRef()
	geomType = baseLayer.GetGeomType()
	baseLayerName = baseLayer.GetName()
	# 载入掩膜矢量文件
	maskData = ogr.Open(maskFilePath)
	maskLayer = maskData.GetLayer()
	maskLayerName = maskLayer.GetName()
	# 生成裁剪后的矢量文件
	outLayerName = maskLayerName + "_Clip_" + baseLayerName
	outFilePath = saveFolderPath
	gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
	driver = ogr.GetDriverByName("ESRI Shapefile")
	outData = driver.CreateDataSource(outFilePath)
	outLayer = outData.CreateLayer(outLayerName, spatial, geomType)
	baseLayer.Clip(maskLayer, outLayer)
	outData.Release()
	baseData.Release()
	maskData.Release()
	return outFilePath


if __name__ == "__main__":
	baseFilePath = 'dataset/veg_E_22.shp'
	maskFilePath = 'dataset/E22_Bound.shp'
	saveFolderPath = 'dataset/E22.shp'
	outFilePath=ShapeClip(baseFilePath,maskFilePath,saveFolderPath)
	print(outFilePath)

矢量转栅格

from osgeo import gdal, ogr, gdalconst
def shp2Raster(shp,templatePic,output,nodata):
    """
    shp:字符串,一个矢量,从0开始计数,整数
    templatePic:字符串,模板栅格,一个tif,地理变换信息从这里读,栅格大小与该栅格一致
    output:字符串,输出栅格,一个tif
    field:字符串,栅格值的字段
    nodata:整型或浮点型,矢量空白区转换后的值
    """
    ndsm = templatePic
    data = gdal.Open(ndsm, gdalconst.GA_ReadOnly)
    geo_transform = data.GetGeoTransform()
    proj=data.GetProjection()
    #source_layer = data.GetLayer()
    x_min = geo_transform[0]
    y_max = geo_transform[3]
    x_max = x_min + geo_transform[1] * data.RasterXSize
    y_min = y_max + geo_transform[5] * data.RasterYSize
    x_res = data.RasterXSize
    y_res = data.RasterYSize
    mb_v = ogr.Open(shp)
    mb_l = mb_v.GetLayer()
    pixel_width = geo_transform[1]
    #输出影像为24位整型
    target_ds = gdal.GetDriverByName('GTiff').Create(output, x_res, y_res, 1, gdal.GPI_RGB)

    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(proj)
    band = target_ds.GetRasterBand(1)
    NoData_value = nodata
    band.SetNoDataValue(NoData_value)
    band.FlushCache()
    gdal.RasterizeLayer(target_ds, [1], mb_l, options=['ALL_TOUCHED=TRUE'])

    target_ds = None

if __name__ == "__main__":
    shp = "dataset/E22.shp"
    templatePic= "dataset/E22.tif"
    output = "dataset/E22_mask.tif"
    nodata=0
    shp2Raster(shp,templatePic,output,nodata)
    

多个shp裁剪shp

import os
import os
import numpy as np
from osgeo import gdal, gdalnumeric, ogr, osr, gdal_array
gdal.UseExceptions()

def world2Pixel(geoMatrix, x, y):
  """
  Uses a gdal geomatrix (gdal.GetGeoTransform()) to calculate
  the pixel location of a geospatial coordinate
  """
  ulX = geoMatrix[0]
  ulY = geoMatrix[3]
  xDist = geoMatrix[1]
  yDist = geoMatrix[5]
  rtnX = geoMatrix[2]
  rtnY = geoMatrix[4]
  pixel = int((x - ulX) / xDist)
  line = int((ulY - y) / xDist)
  return (pixel, line)

#
#  EDIT: this is basically an overloaded
#  version of the gdal_array.OpenArray passing in xoff, yoff explicitly
#  so we can pass these params off to CopyDatasetInfo
#
def OpenArray( array, prototype_ds = None, xoff=0, yoff=0 ):
    # ds = gdal.Open( gdalnumeric.GetArrayFilename(array))
    ds = gdal_array.OpenArray(array)

    if ds is not None and prototype_ds is not None:
        if type(prototype_ds).__name__ == 'str':
            prototype_ds = gdal.Open( prototype_ds )
        if prototype_ds is not None:
            gdalnumeric.CopyDatasetInfo( prototype_ds, ds, xoff=xoff, yoff=yoff )
    return ds


def write_img(filename,im_proj,im_geotrans,im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset

pre_path='dataset/pre/'
labellist = filter(lambda x: x.find('label')!=-1, os.listdir(pre_path))
list1 = list(map(lambda x: x[:], labellist))
label_name=pre_path +  list1[0]

boundarylist = filter(lambda x: x.find('shp')!=-1, os.listdir(pre_path+'boundary/'))
list2 = list(map(lambda x: x[:], boundarylist))


imagelist = filter(lambda x: x.find('tif')!=-1, os.listdir(pre_path))
list3 = list(map(lambda x: x[:], imagelist))
img_path=pre_path +  list3[0]


"""
矢量裁剪
:param label_name: 要裁剪的矢量文件
:param boundary_name: 掩膜矢量文件
img_path: 影像
:param saveFolderPath: 裁剪后的矢量文件保存目录
:return:
"""
ogr.RegisterAll()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
# 载入要裁剪的矢量文件

labelData = ogr.Open(label_name)

labelLayer = labelData.GetLayer( os.path.split( os.path.splitext( label_name )[0] )[1] )

spatial = labelLayer.GetSpatialRef()
geomType = labelLayer.GetGeomType()


# 载入掩膜矢量文件

def new_func(outLayerName):
    return outLayerName

for i in list2:
    boundary_name=pre_path+'boundary/'+ i
    maskData = ogr.Open(boundary_name)
    maskLayer = maskData.GetLayer()
    #裁剪shp
    # 生成裁剪后的矢量文件
    save_shp_dir='./dataset/pre/shp/'
    if not os.path.exists(save_shp_dir):
        os.mkdir(save_shp_dir)
    outLayerName = (save_shp_dir+i)
    gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
    driver = ogr.GetDriverByName("ESRI Shapefile")
    outData = driver.CreateDataSource(outLayerName)
    outLayer = outData.CreateLayer(new_func(outLayerName), spatial, geomType)
    labelLayer.Clip(maskLayer, outLayer)
    outData.Release()
    maskData.Release()

    #裁剪tif

    shp = "dataset/E22_Bound.shp"
    img = "dataset/CGdomYRJ-114(CK0-17)_E_22.tif"
    out = "dataset/E22.tif"
    # Load the source data as a gdalnumeric array
    # srcArray = gdalnumeric.LoadFile(raster_path)

    # Also load as a gdal image to get geotransform
    # (world file) info
    srcImage = gdal.Open(raster_path)
    geoTrans = srcImage.GetGeoTransform()
    geoProj = srcImage.GetProjection()

    # Create an OGR layer from a boundary shapefile
    shapef = ogr.Open(shapefile_path)
    lyr = shapef.GetLayer( os.path.split( os.path.splitext( shapefile_path )[0] )[1] )
    poly = lyr.GetNextFeature()

    # Convert the layer extent to image pixel coordinates
    minX, maxX, minY, maxY = lyr.GetExtent()
    ulX, ulY = world2Pixel(geoTrans, minX, maxY)
    lrX, lrY = world2Pixel(geoTrans, maxX, minY)

    # Calculate the pixel size of the new image
    pxWidth = int(lrX - ulX)
    pxHeight = int(lrY - ulY)

    # clip = srcArray[:, ulY:lrY, ulX:lrX]
    clip = srcImage.ReadAsArray(ulX,ulY,pxWidth,pxHeight)   #***只读要的那块***

    #
    # EDIT: create pixel offset to pass to new image Projection info
    #
    xoffset =  ulX
    yoffset =  ulY
    print ("Xoffset, Yoffset = ( %f, %f )" % ( xoffset, yoffset ))

    # Create a new geomatrix for the image
    geoTrans = list(geoTrans)
    geoTrans[0] = minX
    geoTrans[3] = maxY

    write_img(save_path, geoProj, geoTrans, clip)
    gdal.ErrorReset()
labelData.Release()

汇总:输入shp和影像,输出影像对应的标签

#影像裁剪shp,转为栅格,为该影像标签
#输入:存放影像文件夹dataset/sat_train,存放标签矢量文件夹dataset/mask_shp
#输出:标签(栅格),存放在dataset/mask_train

from osgeo import gdal, ogr, osr, gdalconst
import fnmatch
import os

def ShapeClip(
		baseFilePath,
		maskFilePath,
		saveFolderPath):
	"""
	矢量裁剪
	:param baseFilePath: 要裁剪的矢量文件
	:param maskFilePath: 掩膜矢量文件
	:param saveFolderPath: 裁剪后的矢量文件保存目录
	:return:
	"""
	ogr.RegisterAll()
	gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
	# 载入要裁剪的矢量文件

	baseData = ogr.Open(baseFilePath)
	baseLayer = baseData.GetLayer( os.path.split( os.path.splitext( baseFilePath )[0] )[1] )

	spatial = baseLayer.GetSpatialRef()
	geomType = baseLayer.GetGeomType()
	baseLayerName = baseLayer.GetName()
	# 载入掩膜矢量文件
	maskData = ogr.Open(maskFilePath)
	maskLayer = maskData.GetLayer()
	maskLayerName = maskLayer.GetName()
	# 生成裁剪后的矢量文件
	outLayerName = maskLayerName + "_Clip_" + baseLayerName
	outFilePath = saveFolderPath
	gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
	driver = ogr.GetDriverByName("ESRI Shapefile")
	outData = driver.CreateDataSource(outFilePath)
	outLayer = outData.CreateLayer(outLayerName, spatial, geomType)
	baseLayer.Clip(maskLayer, outLayer)
	outData.Release()
	baseData.Release()
	maskData.Release()
	return outFilePath

def shp2Raster(shp,templatePic,output,nodata):
    """
    shp:字符串,一个矢量,从0开始计数,整数
    templatePic:字符串,模板栅格,一个tif,地理变换信息从这里读,栅格大小与该栅格一致
    output:字符串,输出栅格,一个tif
    field:字符串,栅格值的字段
    nodata:整型或浮点型,矢量空白区转换后的值
    """
    ndsm = templatePic
    data = gdal.Open(ndsm, gdalconst.GA_ReadOnly)
    geo_transform = data.GetGeoTransform()
    proj=data.GetProjection()
    #source_layer = data.GetLayer()
    x_min = geo_transform[0]
    y_max = geo_transform[3]
    x_max = x_min + geo_transform[1] * data.RasterXSize
    y_min = y_max + geo_transform[5] * data.RasterYSize
    x_res = data.RasterXSize
    y_res = data.RasterYSize
    mb_v = ogr.Open(shp)
    mb_l = mb_v.GetLayer()
    pixel_width = geo_transform[1]
    #输出影像为24位整型
    target_ds = gdal.GetDriverByName('GTiff').Create(output, x_res, y_res, 1, gdal.GPI_RGB)

    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(proj)
    band = target_ds.GetRasterBand(1)
    NoData_value = nodata
    band.SetNoDataValue(NoData_value)
    band.FlushCache()
    gdal.RasterizeLayer(target_ds, [1], mb_l, options=['ALL_TOUCHED=TRUE'])

    target_ds = None

print("开始制作标签")
ogr.RegisterAll()
img_path="dataset/sat_train/" #影像所在的文件夹
mask_shp_path="dataset/mask_shp/" #原始标签shp位置

shape_path="dataset/mask_boundary_shp/" #shape输出位置
mask_clip_path='dataset/mask_clip_train/'#裁剪后shp
mask_train_path='dataset/mask_train/'#最终输出标签文件夹
if not os.path.exists(shape_path):
    os.mkdir(shape_path)
if not os.path.exists(mask_clip_path):
    os.mkdir(mask_clip_path)

imagelist = filter(lambda x: x.find('shp')!=-1, os.listdir(mask_shp_path))
list = list(map(lambda x: x[:], imagelist))
mask_shp_name=mask_shp_path +  list[0]
img_list = fnmatch.filter(os.listdir(img_path), '*.tif')
for img in img_list:
    p_img=img_path+img
    outfilename = shape_path+img[:-4]+".shp"
    dataset = gdal.Open(p_img)
    oDriver = ogr.GetDriverByName('ESRI Shapefile')
    oDS = oDriver.CreateDataSource(outfilename)
    srs = osr.SpatialReference(wkt=dataset.GetProjection())
    geocd = dataset.GetGeoTransform()
    oLayer = oDS.CreateLayer("polygon", srs, ogr.wkbPolygon)
    oDefn = oLayer.GetLayerDefn()
    row = dataset.RasterXSize
    line = dataset.RasterYSize
    geoxmin = geocd[0]
    geoymin = geocd[3]
    geoxmax = geocd[0] + (row) * geocd[1] + (line) * geocd[2]
    geoymax = geocd[3] + (row) * geocd[4] + (line) * geocd[5]
    ring = ogr.Geometry(ogr.wkbLinearRing)
    ring.AddPoint(geoxmin, geoymin)
    ring.AddPoint(geoxmax, geoymin)
    ring.AddPoint(geoxmax, geoymax)
    ring.AddPoint(geoxmin, geoymax)
    ring.CloseRings()
    poly = ogr.Geometry(ogr.wkbPolygon)
    poly.AddGeometry(ring)
    outfeat = ogr.Feature(oDefn)
    outfeat.SetGeometry(poly)
    oLayer.CreateFeature(outfeat)
    outfeat = None
    oDS.Destroy()
    mask_train_name = mask_clip_path+img[:-4]+".shp"
    #裁剪
    outFilePath=ShapeClip(mask_shp_name,outfilename, mask_train_name)
    #矢量转栅格
    output = mask_train_path + img
    nodata=0
    shp2Raster(mask_train_name,p_img,output,nodata)
    print(output)
    
print('标签制作完成')

 做mask

'根据多个给定范围shp,对画好的标签进行裁剪并转栅格,做为标签样本,对影像进行裁剪,作为影像样本'
'输入:'
'输出'
import os
import os
import numpy as np
from osgeo import gdal, gdalnumeric, ogr, osr, gdal_array
gdal.UseExceptions()

def world2Pixel(geoMatrix, x, y):
  """
  Uses a gdal geomatrix (gdal.GetGeoTransform()) to calculate
  the pixel location of a geospatial coordinate
  """
  ulX = geoMatrix[0]
  ulY = geoMatrix[3]
  xDist = geoMatrix[1]
  yDist = geoMatrix[5]
  rtnX = geoMatrix[2]
  rtnY = geoMatrix[4]
  pixel = int((x - ulX) / xDist)
  line = int((ulY - y) / xDist)
  return (pixel, line)

#
#  EDIT: this is basically an overloaded
#  version of the gdal_array.OpenArray passing in xoff, yoff explicitly
#  so we can pass these params off to CopyDatasetInfo
#
def OpenArray( array, prototype_ds = None, xoff=0, yoff=0 ):
    # ds = gdal.Open( gdalnumeric.GetArrayFilename(array))
    ds = gdal_array.OpenArray(array)

    if ds is not None and prototype_ds is not None:
        if type(prototype_ds).__name__ == 'str':
            prototype_ds = gdal.Open( prototype_ds )
        if prototype_ds is not None:
            gdalnumeric.CopyDatasetInfo( prototype_ds, ds, xoff=xoff, yoff=yoff )
    return ds


def write_img(filename,im_proj,im_geotrans,im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset


def shp2Raster(shp,templatePic,output,nodata):
    """
    shp:字符串,一个矢量,从0开始计数,整数
    templatePic:字符串,模板栅格,一个tif,地理变换信息从这里读,栅格大小与该栅格一致
    output:字符串,输出栅格,一个tif
    field:字符串,栅格值的字段
    nodata:整型或浮点型,矢量空白区转换后的值
    """
    ndsm = templatePic
    data = gdal.Open(ndsm, gdalconst.GA_ReadOnly)
    geo_transform = data.GetGeoTransform()
    proj=data.GetProjection()
    #source_layer = data.GetLayer()
    x_min = geo_transform[0]
    y_max = geo_transform[3]
    x_max = x_min + geo_transform[1] * data.RasterXSize
    y_min = y_max + geo_transform[5] * data.RasterYSize
    x_res = data.RasterXSize
    y_res = data.RasterYSize
    mb_v = ogr.Open(shp)
    mb_l = mb_v.GetLayer()
    pixel_width = geo_transform[1]
    #输出影像为24位整型
    target_ds = gdal.GetDriverByName('GTiff').Create(output, x_res, y_res, 1, gdal.GPI_RGB)
 
    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(proj)
    band = target_ds.GetRasterBand(1)
    NoData_value = nodata
    band.SetNoDataValue(NoData_value)
    band.FlushCache()
    gdal.RasterizeLayer(target_ds, [1], mb_l, options=['ALL_TOUCHED=TRUE'])
 
    target_ds = None


pre_path='dataset/pre/'

mask_train_path='dataset/mask_train/'#最终输出标签文件夹

labellist = filter(lambda x: x.find('label')!=-1, os.listdir(pre_path))
list1 = list(map(lambda x: x[:], labellist))
label_name=pre_path +  list1[0]

boundarylist = filter(lambda x: x.find('.shp')!=-1, os.listdir(pre_path+'boundary/'))
list2 = list(map(lambda x: x[:], boundarylist))


imagelist = filter(lambda x: x.find('tif')!=-1, os.listdir(pre_path+'img'))
list3 = list(map(lambda x: x[:], imagelist))
img_path=pre_path +  list3[0]


"""
矢量裁剪
:param label_name: 要裁剪的矢量文件
:param boundary_name: 掩膜矢量文件
img_path: 影像
:param saveFolderPath: 裁剪后的矢量文件保存目录
:return:
"""
print('开始用矢量范围裁剪影像')
ogr.RegisterAll()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
# 载入要裁剪的矢量文件

labelData = ogr.Open(label_name)

labelLayer = labelData.GetLayer( os.path.split( os.path.splitext( label_name )[0] )[1] )

spatial = labelLayer.GetSpatialRef()
geomType = labelLayer.GetGeomType()


# 载入掩膜矢量文件

def new_func(outLayerName):
    return outLayerName

for i in list2:
    boundary_name=pre_path+'boundary/'+ i
    maskData = ogr.Open(boundary_name)
    maskLayer = maskData.GetLayer()
    #裁剪shp
    # 生成裁剪后的矢量文件
    save_shp_dir='./dataset/pre/shp/'
    if not os.path.exists(save_shp_dir):
        os.mkdir(save_shp_dir)
    outLayerName = (save_shp_dir+i)
    gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
    driver = ogr.GetDriverByName("ESRI Shapefile")
    outData = driver.CreateDataSource(outLayerName)
    outLayer = outData.CreateLayer(new_func(outLayerName), spatial, geomType)
    labelLayer.Clip(maskLayer, outLayer)

    lyr = maskData.GetLayer( os.path.split( os.path.splitext( boundary_name )[0] )[1] )
    shpminX, shpmaxX, shpminY, shpmaxY = lyr.GetExtent()



    #裁剪tif
    flag=0
    for j in list3:
        raster_path = pre_path+'img/'+j
        srcImage = gdal.Open(raster_path)
        geocd = srcImage.GetGeoTransform()
        geoProj = srcImage.GetProjection()
        row = srcImage.RasterXSize
        line = srcImage.RasterYSize
        tifxmin = geocd[0]
        tifymin = geocd[3]
        tifxmax = geocd[0] + (row) * geocd[1] + (line) * geocd[2]
        tifymax = geocd[3] + (row) * geocd[4] + (line) * geocd[5]
        if shpminX>=tifxmin and shpmaxX<=tifxmax and shpminY<=tifymin and shpmaxY>=tifymax:
            ulX, ulY = world2Pixel(geocd, shpminX, shpmaxY)
            lrX, lrY = world2Pixel(geocd, shpmaxX, shpminY)
            # Calculate the pixel size of the new image
            pxWidth = int(lrX - ulX)
            pxHeight = int(lrY - ulY)
            clip = srcImage.ReadAsArray(ulX,ulY,pxWidth,pxHeight)   #***只读要的那块***
            xoffset =  ulX
            yoffset =  ulY
            geoTrans = list(geoTrans)
            geoTrans[0] = shpminX
            geoTrans[3] = shpmaxY
            save_path='dataset/sat_train/'+i[:-4]+'.tif'
            write_img(save_path, geoProj, geoTrans, clip)
            gdal.ErrorReset()
            outData.Release()
            maskData.Release()
            flag=1
            output = mask_train_path + i[:-4] +'.tif'
            nodata=0
            shp2Raster(outLayerName,save_path,output,nodata)
    if flag==0:
        print(raster_path+"没有制作")
    else:
        print(raster_path)
labelData.Release()
    




 做了一半的

import os
import os
import numpy as np
from osgeo import gdal, gdalnumeric, ogr, osr, gdal_array
gdal.UseExceptions()

def world2Pixel(geoMatrix, x, y):
  """
  Uses a gdal geomatrix (gdal.GetGeoTransform()) to calculate
  the pixel location of a geospatial coordinate
  """
  ulX = geoMatrix[0]
  ulY = geoMatrix[3]
  xDist = geoMatrix[1]
  yDist = geoMatrix[5]
  rtnX = geoMatrix[2]
  rtnY = geoMatrix[4]
  pixel = int((x - ulX) / xDist)
  line = int((ulY - y) / xDist)
  return (pixel, line)

#
#  EDIT: this is basically an overloaded
#  version of the gdal_array.OpenArray passing in xoff, yoff explicitly
#  so we can pass these params off to CopyDatasetInfo
#
def OpenArray( array, prototype_ds = None, xoff=0, yoff=0 ):
    # ds = gdal.Open( gdalnumeric.GetArrayFilename(array))
    ds = gdal_array.OpenArray(array)

    if ds is not None and prototype_ds is not None:
        if type(prototype_ds).__name__ == 'str':
            prototype_ds = gdal.Open( prototype_ds )
        if prototype_ds is not None:
            gdalnumeric.CopyDatasetInfo( prototype_ds, ds, xoff=xoff, yoff=yoff )
    return ds


def write_img(filename,im_proj,im_geotrans,im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset

pre_path='dataset/pre/'
labellist = filter(lambda x: x.find('label')!=-1, os.listdir(pre_path))
list1 = list(map(lambda x: x[:], labellist))
label_name=pre_path +  list1[0]

boundarylist = filter(lambda x: x.find('.shp')!=-1, os.listdir(pre_path+'boundary/'))
list2 = list(map(lambda x: x[:], boundarylist))


imagelist = filter(lambda x: x.find('tif')!=-1, os.listdir(pre_path+'img'))
list3 = list(map(lambda x: x[:], imagelist))
img_path=pre_path +  list3[0]


"""
矢量裁剪
:param label_name: 要裁剪的矢量文件
:param boundary_name: 掩膜矢量文件
img_path: 影像
:param saveFolderPath: 裁剪后的矢量文件保存目录
:return:
"""
print('开始用矢量范围裁剪影像')
ogr.RegisterAll()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
# 载入要裁剪的矢量文件

labelData = ogr.Open(label_name)

labelLayer = labelData.GetLayer( os.path.split( os.path.splitext( label_name )[0] )[1] )

spatial = labelLayer.GetSpatialRef()
geomType = labelLayer.GetGeomType()


# 载入掩膜矢量文件

def new_func(outLayerName):
    return outLayerName

for i in list2:
    boundary_name=pre_path+'boundary/'+ i
    maskData = ogr.Open(boundary_name)
    maskLayer = maskData.GetLayer()
    #裁剪shp
    # 生成裁剪后的矢量文件
    save_shp_dir='./dataset/pre/shp/'
    if not os.path.exists(save_shp_dir):
        os.mkdir(save_shp_dir)
    outLayerName = (save_shp_dir+i)
    gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
    driver = ogr.GetDriverByName("ESRI Shapefile")
    outData = driver.CreateDataSource(outLayerName)
    outLayer = outData.CreateLayer(new_func(outLayerName), spatial, geomType)
    labelLayer.Clip(maskLayer, outLayer)
    
    lyr = maskData.GetLayer( os.path.split( os.path.splitext( boundary_name )[0] )[1] )
    shpminX, shpmaxX, shpminY, shpmaxY = lyr.GetExtent()



    #裁剪tif
    flag=0
    for j in list3:
        raster_path = pre_path+'img/'+j
        srcImage = gdal.Open(raster_path)
        geocd = srcImage.GetGeoTransform()
        geoProj = srcImage.GetProjection()
        row = srcImage.RasterXSize
        line = srcImage.RasterYSize
        tifxmin = geocd[0]
        tifymin = geocd[3]
        tifxmax = geocd[0] + (row) * geocd[1] + (line) * geocd[2]
        tifymax = geocd[3] + (row) * geocd[4] + (line) * geocd[5]
        if shpminX>=tifxmin and shpmaxX<=tifxmax and shpminY<=tifymin and shpmaxY>=tifymax:
            ulX, ulY = world2Pixel(geocd, shpminX, shpmaxY)
            lrX, lrY = world2Pixel(geocd, shpmaxX, shpminY)
            # Calculate the pixel size of the new image
            pxWidth = int(lrX - ulX)
            pxHeight = int(lrY - ulY)
            clip = srcImage.ReadAsArray(ulX,ulY,pxWidth,pxHeight)   #***只读要的那块***
            xoffset =  ulX
            yoffset =  ulY
            geocd = list(geocd)
            geocd[0] = shpminX
            geocd[3] = shpmaxY
            save_path='dataset/sat_train/'+i[:-4]+'.tif'
            write_img(save_path, geoProj, geocd, clip)
            gdal.ErrorReset()
            outData.Release()
            maskData.Release()
            flag=1
    if flag==0:
        print(raster_path+"没有制作")
    else:
        print(raster_path)
labelData.Release()
    




版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_61235989/article/details/130709482

智能推荐

【新手科研指南5】深度学习代码怎么读-小白阶段性思路(以手写数字识别应用为例)_深度学习程序怎么读-程序员宅基地

文章浏览阅读6.2k次,点赞6次,收藏26次。我是一个深度学习代码小白,请你用中文写上注释,能让我能轻松理解下面这段代码。注意包含所有函数、调用和参数的注释。以同样的python代码块样式返回你写的代码给我。代码看累了,就看《动手学深度学习》文档:基于PyTorch框架,从底层函数实现基础功能,再到框架的高级功能。努力上路的小白一枚,麻烦路过的大佬指导一二,同时希望能和大家交流学习~争取更新学习这个文档的专栏,记录学习过程。量身定做了一套话术hhh,亲身测试还不错。这个感觉更浅一点儿,之后复习看吧。20天吃掉那只Pytorch。_深度学习程序怎么读

Java学习路线图,看这一篇就够了!-程序员宅基地

文章浏览阅读2.7w次,点赞126次,收藏1.2k次。耗废1024根秀发,Java学习路线图来了,整合了自己所学的所有技术整理出来的2022最新版Java学习路线图,适合于初、中级别的Java程序员。_java学习路线

PCL_Tutorial2-1.7-点云保存PNG_pcl::io:savepng-程序员宅基地

文章浏览阅读4.4k次。1.7-savingPNG介绍代码详情函数详解savePNGFile()源码savePNGFile()源码提示savePNGFile()推荐用法处理结果代码链接介绍PCL提供了将点云的值保存到PNG图像文件的可能性。这只能用有有序的云来完成,因为结果图像的行和列将与云中的行和列完全对应。例如,如果您从类似Kinect或Xtion的传感器中获取了点云,则可以使用它来检索与该云匹配的640x480 RGB图像。代码详情#include <pcl / io / pcd_io.h>#incl_pcl::io:savepng

知乎问答:程序员在咖啡店编程,喝什么咖啡容易吸引妹纸?-程序员宅基地

文章浏览阅读936次。吸引妹子的关键点不在于喝什么咖啡,主要在于竖立哪种男性人设。能把人设在几分钟内快速固定下来,也就不愁吸引对口的妹子了。我有几个备选方案,仅供参考。1. 运动型男生左手单手俯卧撑,右手在键盘上敲代码。你雄壮的腰腹肌肉群活灵活现,简直就是移动的春药。2.幽默男生花 20 块找一个托(最好是老同学 or 同事)坐你对面。每当你侃侃而谈,他便满面涨红、放声大笑、不能自已。他笑的越弱_咖啡厅写代码

【笔试面试】腾讯WXG 面委会面复盘总结 --一次深刻的教训_腾讯面委会面试是什么-程序员宅基地

文章浏览阅读1.2w次,点赞5次,收藏5次。今天 (应该是昨天了,昨晚太晚了没发出去)下午参加了腾讯WXG的面委会面试。前面在牛客上搜索了面委会相关的面经普遍反映面委会较难,因为都是微信的核心大佬,问的问题也会比较深。昨晚还蛮紧张的,晚上都没睡好。面试使用的是腾讯会议,时间到了面试官准时进入会议。照例是简单的自我介绍,然后是几个常见的基础问题:例如数据库索引,什么时候索引会失效、设计模式等。这部分比较普通,问的也不是很多,不再赘述。现在回想下,大部分还是简历上写的技能点。接下来面试官让打开项目的代码,对着代码讲解思路。我笔记本上没有这部分代码,所_腾讯面委会面试是什么

AI绘画自动生成器:艺术创作的新浪潮-程序员宅基地

文章浏览阅读382次,点赞3次,收藏4次。AI绘画自动生成器是一种利用人工智能技术,特别是深度学习算法,来自动创建视觉艺术作品的软件工具。这些工具通常基于神经网络模型,如生成对抗网络(GANs),通过学习大量的图像数据来生成新的图像。AI绘画自动生成器作为艺术与科技结合的产物,正在开启艺术创作的新篇章。它们不仅为艺术家和设计师提供了新的工具,也为普通用户提供了探索艺术的机会。随着技术的不断进步,我们可以预见,AI绘画自动生成器将在未来的创意产业中发挥越来越重要的作用。

随便推点

Flutter ListView ListView.build ListView.separated_flutter listview.separated和listview.builder-程序员宅基地

文章浏览阅读1.7k次。理解为ListView 的三种形式吧ListView 默认构造但是这种方式创建的列表存在一个问题:对于那些长列表或者需要较昂贵渲染开销的子组件,即使还没有出现在屏幕中但仍然会被ListView所创建,这将是一项较大的开销,使用不当可能引起性能问题甚至卡顿直接返回的是每一行的Widget,相当于ios的row。行高按Widget(cell)高设置ListView.build 就和io..._flutter listview.separated和listview.builder

2021 最新前端面试题及答案-程序员宅基地

文章浏览阅读1.4k次,点赞4次,收藏14次。废话不多说直接上干货1.js运行机制JavaScript单线程,任务需要排队执行同步任务进入主线程排队,异步任务进入事件队列排队等待被推入主线程执行定时器的延迟时间为0并不是立刻执行,只是代表相比于其他定时器更早的被执行以宏任务和微任务进一步理解js执行机制整段代码作为宏任务开始执行,执行过程中宏任务和微任务进入相应的队列中整段代码执行结束,看微任务队列中是否有任务等待执行,如果有则执行所有的微任务,直到微任务队列中的任务执行完毕,如果没有则继续执行新的宏任务执行新的宏任务,凡是在..._前端面试

linux基本概述-程序员宅基地

文章浏览阅读1k次。(3)若没有查到,则将请求发给根域DNS服务器,并依序从根域查找顶级域,由顶级查找二级域,二级域查找三级,直至找到要解析的地址或名字,即向客户机所在网络的DNS服务器发出应答信息,DNS服务器收到应答后现在缓存中存储,然后,将解析结果发给客户机。(3)若没有查到,则将请求发给根域DNS服务器,并依序从根域查找顶级域,由顶级查找二级域,二级域查找三级,直至找到要解析的地址或名字,即向客户机所在网络的DNS服务器发出应答信息,DNS服务器收到应答后现在缓存中存储,然后,将解析结果发给客户机。_linux

JavaScript学习手册十三:HTML DOM——文档元素的操作(一)_javascript学习手册十三:html dom——文档元素的操作(一)-程序员宅基地

文章浏览阅读7.9k次,点赞26次,收藏66次。HTML DOM——文档元素的操作1、通过id获取文档元素任务描述相关知识什么是DOM文档元素节点树通过id获取文档元素代码文件2、通过类名获取文档元素任务描述相关知识通过类名获取文档元素代码文件3、通过标签名获取文档元素任务描述相关知识通过标签名获取文档元素获取标签内部的子元素代码文件4、html5中获取元素的方法一任务描述相关知识css选择器querySelector的用法代码文件5、html5中获取元素的方法二任务描述相关知识querySelectorAll的用法代码文件6、节点树上的操作任务描述相关_javascript学习手册十三:html dom——文档元素的操作(一)

《LeetCode刷题》172. 阶乘后的零(java篇)_java 给定一个整数n,返回n!结果尾数中零的数量-程序员宅基地

文章浏览阅读132次。《LeetCode学习》172. 阶乘后的零(java篇)_java 给定一个整数n,返回n!结果尾数中零的数量

php 公众号消息提醒,如何开启公众号消息提醒功能-程序员宅基地

文章浏览阅读426次。请注意,本文将要给大家分享的并不是开启公众号的安全操作风险提醒,而是当公众号粉丝给公众号发消息的时候,公众号的管理员和运营者如何能在手机上立即收到消息通知,以及在手机上回复粉丝消息。第一步:授权1、在微信中点击右上角+,然后选择“添加朋友”,然后选择“公众号”,然后输入“微小助”并关注该公众号。2、进入微小助公众号,然后点击底部菜单【新增授权】,如下图所示:3、然后会打开一个温馨提示页面。请一定要..._php微信公众号服务提示

推荐文章

热门文章

相关标签