00001 __copyright__ = """
00002 Copyright 2008 Sean Ross-Ross
00003 """
00004 __license__ = """
00005 This file is part of SLIMpy .
00006
00007 SLIMpy is free software: you can redistribute it and/or modify
00008 it under the terms of the GNU Lesser General Public License as published by
00009 the Free Software Foundation, either version 3 of the License, or
00010 (at your option) any later version.
00011
00012 SLIMpy is distributed in the hope that it will be useful,
00013 but WITHOUT ANY WARRANTY; without even the implied warranty of
00014 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
00015 GNU Lesser General Public License for more details.
00016
00017 You should have received a copy of the GNU Lesser General Public License
00018 along with SLIMpy . If not, see <http://www.gnu.org/licenses/>.
00019 """
00020
00021 from numpy import ndarray, array,zeros_like
00022 from itertools import izip,starmap,cycle
00023 from pdb import set_trace
00024
00025
class AugmentBase(ndarray):
    """Base class for augmented vector and operator.

    Instances are object-dtype ``ndarray`` views.  Subclasses set
    ``_contained_type`` to the class of the elements that
    ``__attr_func__`` should dispatch method calls to.  An optional
    ``meta`` attribute (anything with a ``copy()`` method) is carried
    along, as a copy, to arrays derived from an instance.
    """

    _contained_type = None

    def __new__(cls, *args, **kw):
        """Build the array with ``numpy.array`` and view it as ``cls``.

        The ``dtype`` keyword is always overridden with ``object``.
        """
        kw['dtype'] = object
        return array(*args, **kw).view(cls)

    def __init__(self, *args, **kw):
        """Accept and ignore the constructor arguments.

        All construction work happens in ``__new__``.  The arguments are
        deliberately not forwarded: ``ndarray`` does not define its own
        ``__init__`` (it resolves to ``object.__init__``), which raises
        ``TypeError`` on Python 3 when given extra arguments.
        """
        ndarray.__init__(self)

    def __array_finalize__(self, obj):
        """Inherit ``_contained_type`` and a copy of ``meta`` from ``obj``.

        ``obj`` is the array this instance was derived from (or ``None``).
        """
        if hasattr(obj, "_contained_type"):
            self._contained_type = type(self)._contained_type

        if hasattr(obj, "meta") and obj.meta is not None:
            self.meta = obj.meta.copy()

    def __array__(self, dtype=None, copy=None):
        """Return the data as a plain ``ndarray``.

        ``dtype`` and ``copy`` are optional so the method can be called
        both without arguments and with the arguments of NumPy's
        ``__array__`` protocol.  Without them a view is returned.
        """
        plain = self.view(ndarray)
        if dtype is not None and plain.dtype != dtype:
            return plain.astype(dtype)
        if copy:
            return plain.copy()
        return plain

    def __array_wrap__(self, obj, context=None, return_scalar=False):
        """Wrap the ufunc result ``obj`` as an instance of this class.

        ``context`` and ``return_scalar`` are accepted only so NumPy can
        pass them; they are not used and an array view is always
        returned.  The result gets this instance's ``_contained_type``
        and a copy of its ``meta`` (when present and not ``None``).
        """
        # Wrap the *result*; viewing ``self`` here would throw the
        # computed values away and hand back the unchanged input.
        new_obj = obj.view(self.__class__)
        new_obj._contained_type = self._contained_type
        return self._copy_meta_to(new_obj)

    def _copy_meta_to(self, target):
        """Give ``target`` a copy of ``self.meta`` if set; return ``target``."""
        meta = getattr(self, "meta", None)
        if meta is not None:
            target.meta = meta.copy()
        return target

    def for_each(self):
        """Return an iterator over the flattened elements."""
        return iter(self.ravel())

    def __attr_func__(self, attr, pkw_obj):
        """Call method ``attr`` on every element of the contained type.

        ``pkw_obj`` is an iterator of ``(args, kwargs)`` pairs such as the
        one returned by ``__pk_expannder__``.  Elements that are not
        instances of ``_contained_type`` are copied through unchanged.

        Returns a new array of the same class and shape.  If this
        instance has a ``meta`` attribute the result refers to the same
        ``meta`` object (it is not copied here).
        """
        # Results are collected in a 1-D base-class object array: plain
        # integer indexing stores each value as-is, whereas assigning
        # through ``.flat`` would try to spread a sequence-valued result
        # over the array instead of storing the object itself.
        results = zeros_like(self.view(ndarray)).ravel()

        for index, item in enumerate(self.flat):
            if isinstance(item, self._contained_type):
                method = getattr(item, attr)
                # NOTE(review): a parameter pair is drawn only for
                # contained-type elements, so per-element parameters are
                # matched by order of those elements, not by position --
                # confirm this is what callers expect.
                p, kw = next(pkw_obj)
                results[index] = method(*p, **kw)
            else:
                results[index] = item

        new_array = results.reshape(self.shape).view(self.__class__)

        if hasattr(self, 'meta'):
            try:
                new_array.meta = self.meta
            except AttributeError:
                # Best effort: keep the result even if ``meta`` cannot
                # be assigned on it.
                pass

        return new_array

    def __pk_helper__(self, _iter, itm):
        """
        aug.__pk_helper__( _iter, itm ) -> None

        Perform an in-place change of ``itm``.  ``_iter`` is an iterable
        of ``(key, value)`` pairs of ``itm``.

        If ``array(value)`` has an empty shape, ``itm[key]`` is replaced
        with a 1-D array holding ``value`` repeated ``self.size`` times.
        Otherwise ``itm[key]`` is replaced with the flattened
        ``array(value)``, which must have ``self.size`` elements.

        Raises ValueError when the sizes do not match.
        """
        for key, val in _iter:
            array_val = array(val)
            if not array_val.shape:
                array_val = array([val] * self.size)
            array_val = array_val.ravel()
            if not self.size == array_val.size:
                msg = "'augmatrix size' %s != 'parameter size' %s,\n\tFor parameter '%s'='%s'" % (self.size, array_val.size, key, val)
                raise ValueError(msg)

            itm[key] = array_val

    def __pk_expannder__(self, *p, **kw):
        """
        aug.__pk_expannder__( *p, **kw ) -> iterator

        Expand every positional and keyword argument with
        ``__pk_helper__`` to ``self.size`` values and return an iterator
        of ``(args_tuple, kwargs_dict)`` pairs, one per element.

        With no arguments at all the iterator is endless and always
        yields ``((), {})``.  Raises ValueError (from ``__pk_helper__``)
        when an argument has the wrong number of elements.
        """
        if not p and not kw:
            return cycle([((), {})])

        columns = list(p)
        self.__pk_helper__(enumerate(columns), columns)
        # Iterate over a snapshot because the helper assigns into ``kw``.
        self.__pk_helper__(list(kw.items()), kw)
        names = list(kw)

        # A generator expression stays lazy on Python 2 and 3 alike.
        return (
            (tuple(column[i] for column in columns),
             dict((name, kw[name][i]) for name in names))
            for i in range(self.size)
        )

    def __obj_or_array__(self, obj):
        """Return ``obj`` as a plain ndarray of ``self.shape`` if it is an
        array, otherwise return ``obj`` unchanged."""
        if isinstance(obj, ndarray):
            return obj.view(ndarray).reshape(self.shape)
        return obj

    def __func__(self, other, name):
        """Apply the ndarray special method ``name`` to ``self`` and ``other``.

        The operation runs on a plain ``ndarray`` view; the result is
        viewed as this class and receives a copy of ``meta``.  If the
        ndarray method returns ``NotImplemented`` it is passed through so
        Python can try the reflected operation.
        """
        o_obj = self.__obj_or_array__(other)
        self_func = getattr(self.view(ndarray), name)
        result = self_func(o_obj)
        if result is NotImplemented:
            return result
        return self._copy_meta_to(result.view(self.__class__))

    def __add__(self, other):
        return self.__func__(other, '__add__')

    def __radd__(self, other):
        return self.__func__(other, '__radd__')

    def __sub__(self, other):
        return self.__func__(other, '__sub__')

    def __rsub__(self, other):
        return self.__func__(other, '__rsub__')

    def __div__(self, other):
        # ndarray only has __div__ on Python 2; use true division elsewhere.
        name = '__div__' if hasattr(ndarray, '__div__') else '__truediv__'
        return self.__func__(other, name)

    def __rdiv__(self, other):
        name = '__rdiv__' if hasattr(ndarray, '__rdiv__') else '__rtruediv__'
        return self.__func__(other, name)

    def __truediv__(self, other):
        return self.__func__(other, '__truediv__')

    def __rtruediv__(self, other):
        return self.__func__(other, '__rtruediv__')

    def __mul__(self, other):
        return self.__func__(other, '__mul__')

    def __rmul__(self, other):
        return self.__func__(other, '__rmul__')

    def __pow__(self, other):
        return self.__func__(other, '__pow__')

    def __neg__(self):
        """Element-wise negation; the result carries a copy of ``meta``."""
        new = self.view(ndarray).__neg__().view(self.__class__)
        return self._copy_meta_to(new)

    def __abs__(self):
        """Element-wise absolute value; the result carries a copy of ``meta``."""
        new = self.view(ndarray).__abs__().view(self.__class__)
        return self._copy_meta_to(new)
00184
00185
00186