00001 """
Define a window/cat operator (``Scatter``) and its helper functions.
00003 """
00004
00005 __copyright__ = """
00006 Copyright 2008 Sean Ross-Ross
00007 """
00008 __license__ = """
00009 This file is part of SLIMpy .
00010
00011 SLIMpy is free software: you can redistribute it and/or modify
00012 it under the terms of the GNU Lesser General Public License as published by
00013 the Free Software Foundation, either version 3 of the License, or
00014 (at your option) any later version.
00015
00016 SLIMpy is distributed in the hope that it will be useful,
00017 but WITHOUT ANY WARRANTY; without even the implied warranty of
00018 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
00019 GNU Lesser General Public License for more details.
00020
00021 You should have received a copy of the GNU Lesser General Public License
00022 along with SLIMpy . If not, see <http://www.gnu.org/licenses/>.
00023 """
00024
00025
00026 from slimpy_base.Core.User.linop.rclinOp import LinearOperator
00027 from slimpy_base.User.AumentedMatrix.MetaSpace import MetaSpace
00028 from slimpy_base.api.linearops.linear_ops import pad
00029 from slimpy_base.User.AumentedMatrix.AugVector import AugVector
00030 from slimpy_base.utils.permute import permute
00031 from slimpy_base.api.functions.functions import cat
00032
00033 from numpy import array, all, clip, size,prod
00034 from numpy import asarray, zeros, product
00035 from slimpy_base.Core.User.linop.linear_operator import Identity
00036
00037
00038
def scatter_gen( inspace, numblocks, overlap ):
    """
    generator scatter_gen(inspace, numblocks, overlap)

    Yield one ``(overlap_kw, kw)`` pair of keyword-argument dicts for every
    block of the grid obtained by splitting ``inspace.shape`` into
    ``numblocks`` blocks per dimension.

    Both dicts use the keys ``beg1, beg2, ...`` and ``end1, end2, ...``
    (one pair per dimension, numbered from 1):

    * ``kw`` -- number of samples before the start (``begN``) and after the
      end (``endN``) of the block along each axis. The block is widened by
      half of ``overlap`` on each side and clipped to ``[0, shape]``.
      ``Scatterf`` passes these to ``pad(inspace, adj=True, **kw)``.
    * ``overlap_kw`` -- how many samples each side of the block was widened
      by. This is less than half the overlap at the data borders, where
      clipping applied.

    @param inspace: space whose ``shape`` is split into blocks.
    @param numblocks: number of blocks per dimension. It may be shorter than
        ``inspace.shape``; missing trailing entries default to 1. Every
        entry must be > 0 and must divide its dimension.
    @param overlap: total overlap between adjacent blocks per dimension. It
        may be shorter than ``inspace.shape``; missing trailing entries
        default to 0.
    """
    shape = inspace.shape
    ndim = len( shape )

    numblocks = list( numblocks )
    overlap = list( overlap )

    # Pad the per-dimension parameters out to the rank of the space:
    # unspecified dimensions are not split and have no overlap.
    assert len( numblocks ) <= ndim, "the parameter numblocks must be a list smaller than the shape of the vector"
    numblocks.extend( [1] * ( ndim - len( numblocks ) ) )

    assert len( overlap ) <= ndim, "the parameter overlap must be a list smaller than the shape of the vector"
    overlap.extend( [0] * ( ndim - len( overlap ) ) )
    overlap = array( overlap )

    for brk, ln in zip( numblocks, shape ):
        assert brk > 0, "numblocks must be greater than zero: got %s" % ( brk, )
        # '%%' renders a literal percent sign between the two values.
        assert ln % brk == 0, ( "the numblocks must divide without remainder "
                                "the length of the data: got %s %% %s" % ( ln, brk ) )

    shape_arr = array( shape )
    # Floor division keeps block sizes integral (plain '/' yields floats
    # under true division).
    w_size = shape_arr // array( numblocks )
    assert all( w_size >= overlap )

    # Half of the total overlap is added on each side of a block.
    half_overlap = overlap // 2

    def named( prefix, values ):
        # Build {'<prefix>1': values[0], '<prefix>2': values[1], ...}.
        return dict( ( "%s%s" % ( prefix, axis + 1 ), value )
                     for axis, value in enumerate( values ) )

    for block in permute( [ range( n ) for n in numblocks ] ):
        # Nominal (non-overlapping) block: samples before it and after it.
        begin_tmp = asarray( block ) * w_size
        end_tmp = shape_arr - ( begin_tmp + w_size )

        # Widen by half the overlap on each side without leaving the axis.
        begin = clip( begin_tmp - half_overlap, 0, shape_arr )
        end = clip( end_tmp - half_overlap, 0, shape_arr )

        kw = named( "beg", begin )
        kw.update( named( "end", end ) )

        # Amount by which each side was actually widened.
        overlap_kw = named( "beg", begin_tmp - begin )
        overlap_kw.update( named( "end", end_tmp - end ) )

        yield overlap_kw, kw
00112
00113
def Scatterf( inspace, numblocks, overlap ):
    '''
    Scatterf(inspace, numblocks, overlap) -> ol_list, padlst

    Build one window operator per block of ``inspace``, together with the
    keyword arguments describing each block's overlap.

    @param numblocks: number of blocks in each dimension; may be a list
        no longer than space.shape.
    @type numblocks: list
    @param overlap: the total overlap between adjacent blocks.
    @type overlap: list

    @return: ol_list: a list of keyword-argument dicts that may be used in a
                 gather operator to window the overlap of each block.
             padlst: a list of window operators (``pad`` with ``adj=True``)
                 that may operate on the data, in the same block order.
    '''
    overlaps, windows = [], []

    for gather_kw, pad_kw in scatter_gen( inspace, numblocks, overlap ):
        # Build the window for this block, then record its overlap description.
        windows.append( pad( inspace, adj=True, **pad_kw ) )
        overlaps.append( gather_kw )

    return overlaps, windows
00141
00142
00143
00144
class Scatter( LinearOperator ):
    """
    S = Scatter(inspace, blocksize=None, numblocks=None, overlap=None)

    Scatter a vector into an augmented vector of (possibly overlapping)
    blocks. The adjoint windows the overlap out of each block and
    concatenates the blocks back into one vector.

    Exactly one of ``blocksize`` or ``numblocks`` must be given.

    @param inspace: space of the vector to be split.
    @param blocksize: length of a block in each dimension. ``0`` leaves that
        dimension unsplit; every non-zero entry must divide the length of
        its dimension.
    @param numblocks: number of blocks in each dimension. It may be shorter
        than ``inspace.shape``.
    @param overlap: total overlap between adjacent blocks in each dimension.
    @raise TypeError: if both or neither of ``blocksize`` and ``numblocks``
        are given.

    When ``numblocks`` is passed as a keyword and every entry is 1,
    construction returns ``Identity(inspace)`` instead of a Scatter.
    """

    name = "scatter"

    def __new__( cls, *params, **kparams ):
        # Short-circuit to an Identity operator when no splitting is requested.
        if 'numblocks' in kparams:
            numblocks = array( kparams['numblocks'] )
            if all( numblocks == 1 ):
                # inspace may be given positionally or by keyword.
                inspace = params[0] if params else kparams['inspace']
                return Identity( inspace )

        return LinearOperator.__new__( cls, *params, **kparams )

    def __init__( self, inspace, blocksize=None, numblocks=None, overlap=None ):

        def given( value ):
            # Truthiness of an optional parameter that is safe for numpy
            # arrays (bool() raises on multi-element arrays). None and empty
            # sequences count as "not given"; scalars fall back to bool().
            if value is None:
                return False
            try:
                return len( value ) > 0
            except TypeError:
                return bool( value )

        if given( blocksize ) == given( numblocks ):
            raise TypeError( "must use exactly one of blocksize or numblocks (got both or neither)" )
        if not given( overlap ):
            overlap = [ ]

        if given( blocksize ):
            # Convert block lengths into block counts per dimension.
            numblocks = []
            for block, lngth in zip( blocksize, inspace.shape ):
                if block == 0:
                    numblocks.append( 1 )
                else:
                    assert lngth % block == 0, "blocksize must divide the length of the data: got %s %% %s" % ( lngth, block )
                    numblocks.append( lngth // block )

        overlap_kw, self.pad_operlist = Scatterf( inspace, numblocks, overlap )
        self.numblocks = numblocks
        self.is_identity = prod( self.numblocks ) == 1
        # Always defined so the attribute exists on the identity path too.
        self.overlap_window = []

        if self.is_identity:
            outspace = inspace
        else:
            # One block space per window operator, arranged as a column.
            outspace = MetaSpace( [[oper.range() for oper in self.pad_operlist]] ).T

            # Per-block operators that strip the overlap again (used by the adjoint).
            self.overlap_window = [ pad( shard, adj=True, **kw )
                                    for shard, kw in zip( outspace.ravel(), overlap_kw ) ]

        LinearOperator.__init__( self, inspace, outspace )

    def applyop( self, other ):
        # A single block needs no splitting or gathering.
        if self.is_identity:
            return other

        if self.isadj:
            return self.applyop_adj( other )
        return self.applyop_fwd( other )

    def applyop_fwd( self, other ):
        # Window ``other`` once per block and stack the pieces as a column.
        res = [ P.applyop( other ) for P in self.pad_operlist ]
        return AugVector( [res] ).T

    def applyop_adj( self, other ):
        assert len( other ) == len( self.overlap_window )

        # Strip the overlap from every block.
        other_array = array( other )
        windows = [ O * shard for O, shard in zip( self.overlap_window, other_array.ravel() ) ]

        # Arrange the blocks on their grid, then collapse the grid one axis
        # at a time by concatenating along the leading grid axis with ``cat``
        # dimension 1, 2, ... (presumably 1-based data dimensions).
        other_array = array( windows ).reshape( self.numblocks )
        dim = 0
        while len( other_array.shape ):
            dim += 1
            other_array = apply_along_axis( cat, 0, other_array, dim )

        return other_array.item()
00246
00247
def apply_along_axis( func1d, axis, arr, *args ):
    """ Execute func1d(arr[i], *args), where func1d takes 1-D arrays and arr
    is an N-d array. i varies so as to apply the function along the given
    axis for each 1-d subarray in arr.

    Unlike ``numpy.apply_along_axis``, each result of ``func1d`` is stored as
    a single element, so ``func1d`` may return arbitrary objects. The output
    dtype is taken from the first result. The output has the shape of
    ``arr`` with ``axis`` removed, so a 1-D ``arr`` gives a 0-d array.
    Subarrays are visited in C (row-major) order.

    @raise ValueError: if ``axis`` is out of range, or if any of the other
        axes has length zero (there is then no result to take a dtype from).
    """
    import itertools

    arr = asarray( arr )
    nd = arr.ndim
    if axis < 0:
        axis += nd
    if axis >= nd:
        raise ValueError( "axis must be less than arr.ndim; axis=%d, rank=%d."
                          % ( axis, nd ) )
    if axis < 0:
        raise ValueError( "axis out of range; axis=%d, rank=%d." % ( axis - nd, nd ) )

    other_axes = [ ax for ax in range( nd ) if ax != axis ]
    outshape = tuple( arr.shape[ax] for ax in other_axes )
    if 0 in outshape:
        raise ValueError( "cannot apply_along_axis when any iteration dimension is 0" )

    def slice_at( ind ):
        # Full slice along ``axis``; fixed position ``ind`` on the other axes.
        index = [ slice( None ) ] * nd
        for ax, pos in zip( other_axes, ind ):
            index[ax] = pos
        return arr[tuple( index )]

    # Cartesian product over the remaining axes, last index varying fastest.
    positions = itertools.product( *[ range( n ) for n in outshape ] )

    # The first result fixes the output dtype.
    first = next( positions )
    res = func1d( slice_at( first ), *args )
    outarr = zeros( outshape, asarray( res ).dtype )
    outarr[first] = res

    for ind in positions:
        outarr[ind] = func1d( slice_at( ind ), *args )
    return outarr
00289
00290
class EdgeUpdate( LinearOperator ):
    """
    EdgeUpdate(inspace, numblocks, overlap)

    Work-in-progress operator mapping ``inspace`` to itself. It reshapes its
    input to the ``numblocks`` grid and applies ``edge_swap`` along each
    grid axis, which appears intended to exchange data between adjacent
    blocks.

    NOTE(review): incomplete. ``edge_swap`` builds its windows with ``pad()``
    and no arguments, ``overlap`` is stored but never used, and forward
    application reuses the adjoint path. Confirm the intended design before
    relying on this class.
    """

    def __init__( self, inspace, numblocks, overlap ):
        self.numblocks = numblocks
        self.overlap = overlap
        LinearOperator.__init__( self, inspace, inspace )

    def apply( self, other ):
        # NOTE(review): both branches call the adjoint; no forward
        # implementation exists yet.
        if self.isadj:
            return self.applyop_adj( other )
        else:
            return self.applyop_adj( other )

    def applyop_adj( self, other ):
        other_array = other.reshape( self.numblocks )

        # len(shape) works for numpy arrays (which expose ``ndim``, not
        # ``ndims``) and for other shaped containers.
        # NOTE(review): apply_along_axis drops the processed axis, so later
        # ``dim`` values may exceed the remaining rank. Scatter.applyop_adj
        # always uses axis 0 with a 1-based cat dimension instead; confirm.
        for dim in range( len( other_array.shape ) ):
            other_array = apply_along_axis( edge_swap, dim, other_array, dim )

        return other_array
00313
00314
00315
def edge_swap( arry, dim ):
    """
    Walk over each pair of adjacent entries ``(lst[i], lst[i+1])`` of
    ``arry`` and replace the pair with two ``cat`` results built from
    windowed copies of both entries.

    NOTE(review): unfinished stub.
    - Every window operator is created with ``pad()`` and no arguments.
    - The updated list ``lst`` is never returned, so the function returns
      None.
    - ``EdgeUpdate.applyop_adj`` uses this through ``apply_along_axis``,
      which stores one value per 1-D slice.
    Confirm the intended behaviour before use.

    @param arry: 1-D array of blocks (converted with ``tolist``).
    @param dim: dimension passed through to ``cat``.
    """

    lst = arry.tolist( )
    for i in range( len(lst) - 1 ) :
        a = lst[ i ]
        b = lst[ i+1 ]

        # NOTE(review): window operators are constructed with no arguments.
        W1 = pad()
        W2 = pad()

        anew = W1*a
        bnew = W2*b

        # Concatenate windowed pieces of the pair, a's piece first.
        res1 = cat( [anew, bnew], dim )

        W3 = pad( )
        W4 = pad( )

        anew = W3*b
        bnew = W4*a

        # Concatenate windowed pieces of the pair, b's piece first.
        res2 = cat( [anew, bnew], dim )

        # Overwrite the pair in the local list; the next iteration reads the
        # updated lst[i+1].
        lst[ i ] = res1
        lst[ i+1 ] = res2
00341
00342
00343