From 6faa07f5072360a2b2d347143210b7b5c10895aa Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 5 May 2023 14:19:13 +0800 Subject: [PATCH] initial commit --- algorithms/__init__.py | 0 .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 143 bytes algorithms/hyper_neat/__init__.py | 0 algorithms/neat/__init__.py | 1 + .../neat/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 189 bytes .../neat/__pycache__/pipeline.cpython-39.pyc | Bin 0 -> 1609 bytes .../neat/__pycache__/species.cpython-39.pyc | Bin 0 -> 4781 bytes algorithms/neat/genome/__init__.py | 4 + .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 351 bytes .../__pycache__/activations.cpython-39.pyc | Bin 0 -> 3534 bytes .../__pycache__/aggregations.cpython-39.pyc | Bin 0 -> 3120 bytes .../__pycache__/distance.cpython-39.pyc | Bin 0 -> 2413 bytes .../genome/__pycache__/forward.cpython-39.pyc | Bin 0 -> 6029 bytes .../genome/__pycache__/genome.cpython-39.pyc | Bin 0 -> 6308 bytes .../genome/__pycache__/graph.cpython-39.pyc | Bin 0 -> 5214 bytes .../genome/__pycache__/mutate.cpython-39.pyc | Bin 0 -> 15985 bytes .../genome/__pycache__/utils.cpython-39.pyc | Bin 0 -> 4599 bytes algorithms/neat/genome/activations.py | 138 +++++ algorithms/neat/genome/aggregations.py | 109 ++++ algorithms/neat/genome/crossover.py | 151 +++++ algorithms/neat/genome/distance.py | 71 +++ algorithms/neat/genome/forward.py | 171 ++++++ algorithms/neat/genome/genome.py | 195 +++++++ algorithms/neat/genome/graph.py | 198 +++++++ algorithms/neat/genome/mutate.py | 538 ++++++++++++++++++ algorithms/neat/genome/utils.py | 134 +++++ algorithms/neat/pipeline.py | 41 ++ algorithms/neat/species.py | 190 +++++++ algorithms/neat/stagnation.py | 62 ++ algorithms/numpy/__init__.py | 5 + algorithms/numpy/distance.py | 58 ++ algorithms/numpy/utils.py | 55 ++ examples/__init__.py | 0 examples/genome_test.py | 71 +++ examples/jax_playground.py | 37 ++ examples/xor.py | 40 ++ utils/__init__.py | 1 + utils/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 177 bytes utils/__pycache__/config.cpython-39.pyc | Bin 0 -> 2914 bytes utils/__pycache__/dotdict.cpython-39.pyc | Bin 0 -> 1754 bytes utils/config.py | 78 +++ utils/default_config.json | 108 ++++ utils/dotdict.py | 61 ++ 43 files changed, 2517 insertions(+) create mode 100644 algorithms/__init__.py create mode 100644 algorithms/__pycache__/__init__.cpython-39.pyc create mode 100644 algorithms/hyper_neat/__init__.py create mode 100644 algorithms/neat/__init__.py create mode 100644 algorithms/neat/__pycache__/__init__.cpython-39.pyc create mode 100644 algorithms/neat/__pycache__/pipeline.cpython-39.pyc create mode 100644 algorithms/neat/__pycache__/species.cpython-39.pyc create mode 100644 algorithms/neat/genome/__init__.py create mode 100644 algorithms/neat/genome/__pycache__/__init__.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/activations.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/aggregations.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/distance.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/forward.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/genome.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/graph.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/mutate.cpython-39.pyc create mode 100644 algorithms/neat/genome/__pycache__/utils.cpython-39.pyc create mode 100644 algorithms/neat/genome/activations.py create mode 100644 algorithms/neat/genome/aggregations.py create mode 100644 algorithms/neat/genome/crossover.py create mode 100644 algorithms/neat/genome/distance.py create mode 100644 algorithms/neat/genome/forward.py create mode 100644 algorithms/neat/genome/genome.py create mode 100644 algorithms/neat/genome/graph.py create mode 100644 algorithms/neat/genome/mutate.py create mode 100644 algorithms/neat/genome/utils.py create mode 100644 algorithms/neat/pipeline.py create mode 100644 algorithms/neat/species.py create mode 100644 algorithms/neat/stagnation.py create mode 100644 algorithms/numpy/__init__.py create mode 100644 algorithms/numpy/distance.py create mode 100644 algorithms/numpy/utils.py create mode 100644 examples/__init__.py create mode 100644 examples/genome_test.py create mode 100644 examples/jax_playground.py create mode 100644 examples/xor.py create mode 100644 utils/__init__.py create mode 100644 utils/__pycache__/__init__.cpython-39.pyc create mode 100644 utils/__pycache__/config.cpython-39.pyc create mode 100644 utils/__pycache__/dotdict.cpython-39.pyc create mode 100644 utils/config.py create mode 100644 utils/default_config.json create mode 100644 utils/dotdict.py diff --git a/algorithms/__init__.py b/algorithms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithms/__pycache__/__init__.cpython-39.pyc b/algorithms/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de65dd58d42c45c99fe5aae4128ad2e915a12fd GIT binary patch literal 143 zcmYe~<>g`k0!xR$6cGIwL?8o3AjbiSi&=m~3PUi1CZpdg`kg5;c#6geRM7{oyaOhAqU5Elyoi4=wu#vF!R#wbQch7_h?22JLdj6h*c z##}r8(+-Ri|;uVE>}G zuWI!&Q+#kN@wnYFBys{)K@!SI!Z=MFOy&(qT;Xx|9kmyC6Q2h>1Z%)YaQVoDzvPW5 zb1Bx3q8P>kE*fX$MXhpuS;bB7J+aOu+c_9@ALbd%R)N^i4M{;3v|yk)+TGA=w0YX^ zwO~VTf?M*N3)T?k2xbIxPmeLW0KM4bl(Vi|6j?(Z(1 zy{E=I&-C(ASq2^$e_6{_eu36DUnYkLdRk>ydDT^^ths0nT_{A&9JX7sQV8kUQZ6`> z{d-KAQBya4@II;#AcZBcHKSZMolbLpWum&%y{@3>NV2y>tBEdz6e>g7lh7UwjL(>H zTT!kQmg{?>9>U*0-_I+l=VC5_jc<$WYEXWKkKgA5Kdp7l2bP*wCQNOarRf+pY9D|g z)Ta@PXhe_c5%byl;EwZ?yR|BWWuLH%_@L3U`v??G0D0sHxGjK}0vIa14dwK$|C((W zu))@U*)qX8IR+@BzieTC>-Ttlrw-xIc#X=XHm-iv2sHsqa@;i> zADX+;9}=H2n~50>miw8wT9oBCut|%M5ITk5dNL$}-o-4$3Y1wT($qw0TGhNOQ6H!2 zA6-`Vo*+$my-ZVe1RELmysk@i466DX1@^#>Kp0g30;h#Bq^=uIrhC>3M>^0JNzkWc odpq%=+3HHzezv{1n+a literal 0 HcmV?d00001 diff --git a/algorithms/neat/__pycache__/species.cpython-39.pyc b/algorithms/neat/__pycache__/species.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99958b9503d51f2a8adcd3a7768933dcea8f16dd GIT binary patch literal 4781 zcmbtY-H#hr6`y-&JRXm|UaxmI$)*8}C=dhPKq*p*R;5)#k)Rk*BB%zU(ZqM`9cMiD z-Wey^)mTWRQe31;LE?=fdBs!y2mS~iqx*`4)Mws!p#07q+w0AyePO(F=6;@gzR$TE zY;3p&p1(c3Bho8|@lWckKNdP~<4tn_VF(r*L*}s|_xRBCOh#iowuU9Iq}yg}dp5?* zxa>LjT5)CQdM@Tl!j7xMnpb1SPYh8O&WDC@!i~qKS4Y1hT=d^@pSLHs17AfY9)mFgz}p;2#>EDNJGEYZ2Y}mP8peII4Vm9Co8H?V8xL ze5g&x-o=}~0g$uY@Qs`UG7W(iZL@CxT7qCnLtBrOg#qNegTD^WFz<)Qs@#=fkcGnA zh=Vlqqv2SN9)v>GWjK~$8YWqgMGrz%9)`nvp-fe|7iCG9rd}fugGq|5{$fOJE*kzJ zL@|=bB*Zj~dy;go+S(2QdF~~_0dbt76 zHYKe(8_*mSU>g(4o{R+qc?s>5teNOJsV`|LCcC7)%ZD}*b~;jTBE{aVP1y_#>=}RJN7n5^xSF&!SA74wiVdGJRas~!OxLJ( z4`S%5I~tDPdIc!|jn_D7FRAW_*;2V-x_=OlxbH7cY%D5T3Rj} z=aA5fR7-)K?7sAvPTMB!Cc!ZDedYT8a3m%%wQIhAGzsEjMv}xNNn3u20NK+Tprd64 zfJ`s70UYkQj-|)C=$?B@=)H|MJpefKPN-@_JEUS>%?gydY3rCM$va-izWWJNC0fgondXEF?)!Z7i1gk$4nJwFPG% zL4(=IU+y>I)5>}<493!-9b2QRJA-JG+yRTu#i|Qo{pz*X0K<-xO zE`R6`Gnzo${hr$(*3*5iA|w2?Q!)7E_^pXa5#hLJ7M zwY6u=n=`CME(U!w1N1%b8`EaCxs;VY2AIY`8TGQ@-k+wfbL9V(whc03N>`+ri}GcL&ym6=8YS#28x_~&btwfrV2 zsr?MEyasTxe=#cP=!Crevg~$&Op04N^S+9P&b-%k-Nd!et~-P3&VxzD7U?IGw+A)*_NScy$$FbN(6RW(AL;^a42)oHRKU3BR}ykl(`+ z`F#TCPdcO_63zwelsz@}@1UWvzY1neFv}?z5S!W@iE+*Ti%iqLCDTkI>aysIxsI@5 z#dW8BP4Q%`tllJ!RVh6T#wf$@jYhGusQ!>&$7*GvhD}u|%EDn<)Mg}V*`_&Hgb$)_ z=%WBp&0drUe~2Xa5Cx?_4AT3+nFZbZirrBq86^Es*@5h1iTnobwuA9FOoTGyFj1v| zN`1Mc4-#9vMe%{w;|W_l&UOqCu(wH42josU|R*U2C}nHAz-=6Yfz6ymm|ft z2SMz)nl?;D!fOx}%f1%$4~R>a>PkUq?*mQ#0dpGdJf} zj$C@ma?>y6=D?mKq7O=xF^*h8g*!j-S8=zf{&oCZ6yX#)!=QV3F?JTgdsVBS1<{)? zMpfM5w1vnYflm2D0A&rsASnvsIt^|Tcn_d$tySpSwthiVIu@(~g1S-=x>Em&ijS~Z zSJ8&uwYP=HFB6+P&UgN)z%vy3f+q4i(>)~OD znp&Uh1kUuZQil90CJ0+fR~Uy`mQ&mDoT?tj&Dumr zg1!;RBsthQF_UDeD%WBr9H@;DqI0TP`KdYO=$ko5RRT-KeQ_5{U?JLsN|2!piOQb* zaWVcIeB=iJo=fN9#*ee*&uR3GaPoQi4!~dW7Pg{YT#vPz%0}jolD;GpkW@lTqEb~5 zx@Cr_p6Yu<5xGmrWH>(7^msMh-e2t2610oUMRAR=NB0IO-HQlJGLye1K$NUew-T#z b4vGuqTMR4oe-3t7O^1mZtC@|R`tSb-wZ2!a literal 0 HcmV?d00001 diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py new file mode 100644 index 0000000..46fbb83 --- /dev/null +++ b/algorithms/neat/genome/__init__.py @@ -0,0 +1,4 @@ +from .genome import create_initialize_function +from .distance import distance +from .mutate import create_mutate_function +from .forward import create_forward_function diff --git a/algorithms/neat/genome/__pycache__/__init__.cpython-39.pyc b/algorithms/neat/genome/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f27769da0b1f0024195c2e8e77e29f508eab0d27 GIT binary patch literal 351 zcmYjMOHRWu5Vf7ejjDDF5;w>KykS|Po}ets$_QCj>=sx4Ajc8V8^MvXWyKX(F;1hZ zWBK_#y_xZhTrTGb;_>jLzTy4J;xY(`HQeqVKp}-Ka+qVBIkChMPDWnX!j-%nd1)(0 zIR##AQA2+*ls0&{X`_}&3*${P(wd_dyWY3S1P^YlOqZl@HSnt`)AcDATx4BO&Mw4v z8P(a^G_!&3wGU3mJaQ}mI6XCpC5T}*)x-+a7m9UG%kZ~m&L^gshw*L?^5aaNLA)?& k_aU0}>blWjr*D`T7ZKv_?Q^JCE~wt>hrDV>0Twu)e^PN@Qvd(} literal 0 HcmV?d00001 diff --git a/algorithms/neat/genome/__pycache__/activations.cpython-39.pyc b/algorithms/neat/genome/__pycache__/activations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e75904ac9e60328c3c24faf0a9f3d3c8ca3ba1a5 GIT binary patch literal 3534 zcmbtWOK%%h7@eo>#Lmlk_eJ`kt^0^eO3R~#rm7GW6@zHR!W9bB+>7JXcFdh|>V%X{ z+4cuev8b03g8l&(EC7iWJCJ6FSdD}buq=ykzHdB^Q%{5t6Zz)cdED=OuXC?6Fwn2z zy7=35@qJ3uex*tGHHyhRe(tiaX+jfvMJtB?^{P=cb;3qPt6D`%*Q6y(VSTFI*NZl= zEfS0qz)6u}oCNL@{ftw|| z__@b1c-kX<6Nwo^B81s8?`gsk7Va9GT3T!AliGd_OWK8iJ`29x-rf!_6QY&2OjxMZ zeQT*wuKT8J*0*$D-`LXB0G>X%mao>lyv*06E_AF#QL+jp zT1nHac|~g~axiBq)Ul@eF!+||)RtA6=HyQ}4APSnJt0OJ%yt-!VD`TlDU47Z+;1y}8WbX%rn{P}(Un!bB-pquLk}&q*(x=PUh91m=UksR9gqu97sw zMW?UiG+dV`P2rVxcm{PffmshFa+a%+LK#tN88B}Y6ct&{;FS2us?#i28`W5oDJ3fn zvUmtzxI|Ag6j|@FN37XicY)loHaP&Ba9kKuTFbnN;tuUAs=BSPE`0m$vQ$!$>c#R% zwrL#23lR-6w2KDOqeYpD*cGsnJTC%SpV`PJ*SYgP84NW8wzYJhM3mKx~9@&uB9;!HD@_G{$zqWKBgJ z?Luj8Wb3Cde*EJ~>CZFGFCTyRBqYnBQ{*A~im3W##i{zXQ>iaIJ$&6Z-&wrqQhYSS z(!qEJXcuFuicT6&VG@;1w}|1MA#9AOX_B!PBo6UR6rSp_(!T9l`04=*6x)=5m%Uf*yrk9YQx@5JT@3MM4&MGltkEnbEYJ|56v@IoX zZ`YFL+8TNF8s3V$Isx31M>6})j#(&TF3v19B!`t$ zCNxW~s-u`jVrV+I-FWx#fb6`7m!sgG2JUgL(?C&h1)3h#uu7^ zW{Wqw?097MR+rf`ab|bb3>`d1sVh-4fk5a4eOies<7&bfuc*6T8t zY;Vhs>nP={`*ykJou9*j&}ov)H(~fLPM^#E;MHpjiwkcrT)nY)@JH351UYiZ*553we7{M`%jYsz~V+$q{P;cc`KUQdG{03QloRP$L&J;Ja9J zs&ej*TxYJnwN`cN#lbbX~AYqat4==iE;)r1XrQF%7-ne(Gmo gPCvSzX#bDyMfd4h_dP4CXSXff`PaYS6BD%VKN#tGiU0rr literal 0 HcmV?d00001 diff --git a/algorithms/neat/genome/__pycache__/aggregations.cpython-39.pyc b/algorithms/neat/genome/__pycache__/aggregations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc028f794d4127c88893d80ed784801abc3336ee GIT binary patch literal 3120 zcmbVOTTdHD6rPz~dyQ`pAb=qkH@(SH0}1IxjiM+~B1NqtT8T)lHceMs&*0eDyL5IT zfXow-LLMUJ2lN5@mfzF2eeFZVPiUU{omt-?q^i2s9G|&==gfD`j9)6{8NNT>elJ(Y z82ghdqhAq~J1EgloH5B0eV@E8@&j9LY^~vNjWN~##CMvR8f^0j$6i#lX6N< z%ky$Z&dRyp9CsWuDstY;m{7cRPEE)KdEu3CC)Jd^xGr9E{iCv5(==0MTIZ>xEA(sq z|3(4hX5=NzomJz9b4Tnxdou4<)V!*w2{o=N6{cqQ#lGO|HUHJ7RU1rR{+$ozExdQ$ z!eHJlDOOpny8AHiHyT|g}neZw^I=I#BZ zP7p7tr9k=d)t0|KP#&Z5WXW$gLfwp?b|TZU)KEd#QA?vvSlrl294zahu|#-n2qvbv zF5~}ync9%l~ZM&kvaGD|| z=@}05e~FT;uN!bp&C&%d?q<;D?Ybr7EvU^vs_oN9oJTK>1Nx6KDI!J?`||i{a`1^$ zCiGOul1M|41%!5-5#E;QnymRsyrdqh--{lcc^Y(w85;kRzhE!;8dxgr=ea%6t7T*s zH7m}xoWtA!lg23M#rjeT}8VQZOr?yn!QPTPn6C+xO@!d-u zC7Zq>!{V)WSohn}?Zx5Hckq%U4PM&4kGFr+^Jp8w(hGyC3C=Mul+pnB8XbKA7taLH zAE{}{+E^NHbkrI*VlNjMirZv0GM~SlSZxR!iCBVeM5v?`gn_rKbm(>bXkA~Vsh6lD zn{Jw7H82gr{1D4k42vcpm{sIuF~jqsY?ZA!p^0_xaEy-e9!j(b(Pxq27-C@`5n+;d zk867t@neCd{75rOFUc7bDUTXEdv@p#(cn&u9;}B1oJ^i*sfU|E>`rzRPV*U3c(@rG z+j}H1Y%x~m%n-!ST>@<08luQ^yo{nhrM(Tk$n(c^+{a|xF-6faCE}W6uq93}*J`4J zJ`dq$)7aFMyAqqWsT-Q1N*}F5E6}3R9SArDjHfV^oI+3{yq>c`aM5Fhg3D^zhZ#6E zd%$IOT^#W=ybur03NE*djCXJ)W4#j}-6E?f<_?hDB%{CjBSu&{2qT@clrC)JxW_cP z7bY`s*DyW?L*Chxykx?Is6*x*3K(b|@FNQV_zEtwDthz>5J@(|In|#g<+~3ac#pn& zboZh6@SD{~1{-49YRYZH;}onx2S&%r%*zk?S- z5Z_oH!qHc#+w$9Ofcye_y7zU`C}X53J}1pUv9RF}sP!`NgvcR2^i8zXM0$JE3E;Kk z&|R~wq9o@b4pH$mMG1DH0jEUmLgILy4C@$GMRziW&|VBSJ9u?u%wKAT2sSjGjd^rv zy4}rvYVI&|4VXAh>3jpdZV_<3!0)K#b+uz)CiYgx-*C%YsvhIfvy0q`F@%zaVgIja ze~X?`&RRS4;jJW-`p(805%6Ff{x&pRNtYSix_p{hX~^B`;9b&Ix5?t@TL?Rw;}mp7 jOYnT=RB=$BLdi{{cjWhXwt&`RS>*XLX3yc;EAoE;Sfqp& literal 0 HcmV?d00001 diff --git a/algorithms/neat/genome/__pycache__/distance.cpython-39.pyc b/algorithms/neat/genome/__pycache__/distance.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1979c7e38790b51344344dcf4b5b4df548b613cf GIT binary patch literal 2413 zcmah~OK;pZ5EdzEU)Fm4*f@4#v_g96298~C`T|-sX><3X1$x*dSjZy1mMu{Wk{jEF z_T&UTCRZJeDsb3(N!$|66q(_Ju)0Adz&&|A^m-#*4VlJ>suWH9$TFdIade-PQh{Ij( z&tB%g@>zvdpL^$|x4;Q^S#3aA{Uv2LSiqXk3Ge{8#}-)&=n`9g?f`4fVJmFaVvDT( zgVR~Ngndqjl+|G*brPi=w{(vZU3$;5XjuB+N*SGD*%za1c-Eoi>OmT5%|$=Xh2XJH za-qsK#dSX)av7OjyaGNG*c<70ld749dI+6pCCL%5M|g;A^_ z27{+|^EsbQkGTJ@M|8JBoU(^gqSHi$pM)8YaLnTqo(v8(4&Wl%Px)li(la*iNdjk< zjq4*2+a|8Y`?E_+!{JUt*3n552f1KG08uv09#@*BkINv)w(ItMY^ruAUAMP-eZYnmGhQeba9ZVs;r#-54@)*2>=BHcn!I+4`USCl?Sia3ewe!fQLlQsK6po6dN(ScP__ST0tMJtHsN+a?REdvX!ZzBJstq!5;+h?-}Jefbxj zPF>cpfFdP^8epyt{>)Ep5}vmr26l2^nwJU`foC zL7>K^Le=bNkvi@LqCXO*V44|9tSOacF#XKVdaVOF+h+=>_kb0F4d(AAv|Kh8tj>Y6 zEa^>hAr^Y~de|-9XHjZ1aw2Y+TP#6rqNzI&4y}Bg@-_~X=31A=j!W}vc|+xg=n&?8g})yayJ+2#u=hP;9PuUE zq|8u{+``+&E6vJ(U-EOT{@*2G_SFUF%=V^c8;j*VBX%H+l`*W0yYjxlOp%(4+((#x zho(%TIMzw3Ix zD)#oI9&e9yf(^{y!ltGzx|R;sFZSR$6Roj^VP);>EN7#Xe`C^JnNl5qPyt-Pab$UK_XFv2UOP1{3*qEf_}j+dS$=Ts8C<`s{Tx4hWaIM)}L$4`SXo=f1$BJxU#qCFEy5+ z|GHCx|1!ZV&K0}*tZ2{Jvxiz^#bHjtp6f7s{!`geV2v{R+Y8Thd(mDxR1a8V)lux_ zXUx80uN{^3e*X--GYmL_3n7Zy2Sof;^8m+tbf^`P0yJ_E|b=PhA zfn?r3M;V(eEp{y)yOvisqgR!#f#;N+)x4O8nevhWlK@GT3u{s@k zv|>|9=i{&&dSS=d7Q3nTXAcr`FZ>#rfECdJtU=l*!-~mi~DCqI6-JTQCew* zLEyAvHw>bb)zgw2bbE2rwV$Ttuon;e8gWWBqSGqk#Rfm@7oaK7qQ3)33jB5~9jnLk zf&2${BptCM`A9iZkF<8dX108&KY-a?W-Gsu`R-_CqD4QSSgnfH6Ie}PgNH<=08~DU z0MBAZC6CnyROVyFE+M5qOMj6M*zcHK2HjPxcR>eb1AZ;;Sg~WY zL;l3#w$bhdBrQZ3w%cwy+zDbMa^s#wQY_Q18tB~e4H9zAh<2>5W31OVjCV`}_N{lr zZZmTCoehI<`7JqxWE*gL5r3F#8M(1OUKHwcFag?#Rt+EtAlYa4`d9I z2|Au5jP{JB1-eCaR^W@UKh??poktO0fFVc1d>#Q(Ka>9T{+1uaTh3MhdV6!%dYW~9 z34@QfEUy!CH{S6hn%L?%LFhYM1Kyk6y;Q&BS^lqZ1g{`?6#*XY4QSc%F3)ae$c|JF z{HxSy3aL+Hyq;%#j-LNg+|fxjP&_*TmD6QSmM)?z3{Q}|@Ic$sbevXkPi6()W!(6Ma3_b_*xPL{Ex|daC?N z?w9)Iq?9OY62xf5a)&V*9D5X?LDIsG46_};=1Kr6-QX6u-tsd%q z1+>jht5ToEVosyg3|GhUi2_Yc*eL@J(yn$4YflO|K2`d1 z0!KS5)H&v^bG#}y5MUbPUj+b8S;9Wn%}~Jw-Bf$B19d>Exlv#Rscglm5;<73DK6Gk z`wcyCo-`jrmN3;+YgwG{kvFHz6}c-5DXFsK#yg9}j%7SRm(n~c}sF$!;~>bp_&5O5JGpY0FRwdwq91YkF0#9W;ClehWD(&oq?mGU1%{v{F(zWxzp)*j=L74P$T|{n0vLDLV6S!r{#_iJjwZ_&o$8siPT;yoi4bhwlI| zt1lsbejk_r7y+5kd)WI90`jBGnfO(}{TfF!TUje+x!stDuw29ouu?84ZBu24Jk*H=YdlQym_)|fw1+0NgW_@~53d`4E zSqc+K@#2uBv&$#5JP$|(FQ7ikEdLM5Eq~Ebjump`H$hoo3@@Sd30u1j8#y!)D;GG9 z%_7Zr0ku}QI$J-J-@t>on2&PFIOjJN7^tZjzskm;Cddt`x;@} z5M@*JrLCLj2KWI5ULe8daX5GppRA#9YFJ0*mf<`6L`6k3M~WIFdqI6QB6yJjPOG@l zMEFMtFb7TsJA5ts1lrI5aSXOL&5|bwT{?XI<08u<7n>-+3AZ7TO(<%G7&9^d7uF>l zI;+j|{yh;U|HpxhEaS^$ZF$jr&hlPIgP9)sCJbgG^nD#R{wZg9I5ydR^F~^tuOwmU zMI1{=QEVyL!y;6a2lyhv3joq8eHPt}_qy-}h+__IRN)&acRJ$R8WjwUiUnWb9uF{O z+^2Q8)!m^tYI>LEGlcEMF7U-l1M6`3&bkc;5pTL!fKlG3FwN0bIG#G^Q&|-OAB5mT wSI=|qg0)K#R<2c9r(&WiS6P)Uv#KgRFR=3Ss#(SI;`!|Q>|$B;md2m|1{==4fB*mh literal 0 HcmV?d00001 diff --git a/algorithms/neat/genome/__pycache__/genome.cpython-39.pyc b/algorithms/neat/genome/__pycache__/genome.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa17460d5579388127336cbe0fafc9eb8faa5d9a GIT binary patch literal 6308 zcmb7I&5zs073Yu?B~jWBuh*Y(942v+RU*szK+<;WBz9}JP1-1F;ueSkL*R-dE!PyO z45`?=;2xTF&_n)#0&xKw=p~ol`;YX%Q-K~@13mQE`r`iH3?)jts+&^a=*^oq^Jd=n z&|O-xH9W5jZ+PYNn)X-f%swmVT)`XvnQ5A*bu^E8dccNyM`r{ZfiW~Yri9I))Ukjo z1?6F-Q;|3;s18dVTf$|)wN6dK6~OgQUBcC1X}H{39l-PAO-q$g9?4IeZdB^XX zom0H#y@L65{3a{u)RT1U_bT2A@8sRm0qdORwVbXPKZSKpduQ&N7(cUM{4BlpN15 zq<9Gw8(u?FJmZ<4XpNTpwXSLQuX!(tgumw|eiS-U-x=^Q8gfVQk>D{86FJbb?O!E+ z;P3I+xiKCE{If9fc)aOc6~f(ZI=v_iu@IJ+;BGoe*}9oGS$i!o1}FWg{Ev~J>fJmF&Khu}tYhwdGJI37CTc(}y{ZBlY~{7?>uJh>f- zTTRCeJ?FL`1Q_TE?#A5l`!md)E#AY?VmpZ0vvSYb+LbeUVhy_jn=Ft zqx zxtt;~LkXFQc~OsJS~eWDLN|ngJ){WKU|C5T7=ZG~6^ZW#ZGzj^q^Z$3%wV(&z}WE< z>Z<`gtMz%(+wRK!XSV7DZk+TG#HJ4gz8c;iQvQ4?uJ=qVpEx=bSDh(S4%^z;MP^n9yjdFf=>JXK-7tZwFRXBu}Yn~$NTO$ zNV>FswmRue7q3k)u+gmO37>EoHY@+G$OW_9aWS`>?;1c~;*@v>p9{4qK8quMB%) zeKW_}z+(ObwqH|_;sW}5T1%^5#jAX!r!}wIH!f>weTu2)7;^^0Bu!}sQyzogaF#$M}vN<($` zAFxW>!Smi!pLw>3Rs77YNX24=7)g9p4B|}*(bS3iI2(SLG&Y^~G*jdjPLaN_m)3M* zdotsL@04STiG?ptRw#H%A5}4MenRfBM`x&k!&4^iLyqTcb_pFr`sR`pH~gTf_yyK_0v5EI7b|lm~5j{EqX# zI87BfISmrw6ZkC3o!^Dg0{g zL?kwy_f$a>qZ*Q6$-CW{zt8VP$^K;pkhG{qnOCxHN!=fba{$JZJ=`P zQQ+pL^fZ}{ni8BjQWGjeI_mn~9kl_q8|tfC39d`N`jE2V-qH;eVe!yKwbtHr_KrzS zs|6?z0<|VfjeHS>d&Hi7Sicg7xBg$XteWZnf}PH}=hh-KT0q_Y7nfwqqbupo&RLp6Qui>BvH|C79y6qV`J2DmrqGqPLKNtO@<> zESgJ^o+_Qbj%aS}{Z{Ju`kfKgiACX0(Z}tf{2oRg4AL1!C)^mr2btw;+ z!OkdQZcGX%$Ec*NI9+d)~kmUclp+0{D!!=VYlXS+1$+%1X9& zDU>!!;z!tQak8vBIc+Si&_zi|)3yoGWmFf>65Xjp=5Sp1Nh%hbG0%{yo^ z`}zkT-T1WI{_xs$9R3R73j?~2ZfA*Zd0k|>fFyKC0OA@*zQ!BNZ(gNUeHDRrl}Y@K zMmejJ=A?MCauP{khQT^!X65{rt3)FO2sAyk>uC@s^F^SY^*#$$WaU-_PXx@QMO9!? zbT!N`iU$^X+@p(!JgzX0C-_}%D<{Y;iV5A^T&m({*p+yXnnG-JD)BxwB*3g9MF81l z79?;~hS~D`;=jYBspsK>#l(leAc}rKOSY+5#7p|%tn>(`SNdQRKx$A5qv(amgzE{+ zg(W&JJh)5`aOr{JAeE(tj8MAxIi|E>z*A#U2J|k_u&5_*X5LK%ne`v=#uO7Y>jIM) z7~CT9$%9BjBE~}|KwcsbLNV@;8Q}JYBdB0flCUxZ3uZuH32>wXx(ykRR2Z0smO`ON z0*thkJoe@`a60$(V%EeA$_vjy%AXO_yhyTRM^^hI=#}1O{D9s+1_nY}B zg>(VHVr`>Dnb-11Af4B80~nRK&4sMz4cXTP zLd*{}bcww&{dcl*x9df{ZWn&DJHpp2p<@~%3VpYWq9B%ETm)H$tT>SK>DyH+CnjMa zU%Gdc*^(y+vZT(r;u}bdzLK=mEw(krH`rLF1{E&J(4`#2PYI%L8%UWU@~ylq%34yE z8!|J>a4*9iz7=72(rKlcOEpTN%mhyK&?YTTUe|(Lvkr{ zDbKE~4=gW{3(Exp^wg6N*3q~A1^o+(qUdq2Jrw>AissP%-Yh?sY#cX5QJ_GVnzu7= zKHi%*^WOWd{H3L`h38i%cjRv`Th^bck$);^{0JrbBeN`NH7v;_@398MTl83;H+Ww( zMBi@MjQVWwoJPUaonEnDYLxm5jRj^Yu59|r!V@8jvh>hCWsP!l=VZ$ zOETj-s8b6)7?riqo@tFn5^uh=oza~C;CgKiO{;1}EvieZddg23EVrV{Gp!|R$rV{W zg$+%A#sP1R0~^e7plimF{Ek&yJ^7@3XV6z}Tl@PR*YC9lI*2>{$PN1YJ=IrXtfU)+ zZr4Ammp=)6$L>Ltaau~r>#c!4^tE(D6(0`tUQ{nXr^~Zd6RGG1acV#4`}?VLTWkNg z#!@Fd=_@pCcc z80RFsTZoJ2d?=(C+K;Rjla4GrEDfEw)LrP76DJXgoj`)*G**(N;%D|fJ}eB2N#ULt zieV8o0qd6h-$CuNi^q3ij_dwzPq`12j<7X%&~oF!{-8H#2Ti}{MFSnb;l>?5{&wUx zwSRK_hMBp!?`yyB8q;iMFV`n~=6iD%+jMh{{K(9Zn|Yxw3B_?n>i&TWo2qQqBz=E+ zZ?$?d@Yj-DoG~w=zo%LD9&3VQ<^Ne?QCNp+wlWJc5FOfGHPtxhR zdt=?j-%ZdtL9@N{dgXd)`itk%`csW*+T7S#i{#e*uj({03~V~7&)>9b<~-N)cctcD zzh#UFZQogUM~CwqVc*))M4z$L_K$*S#usgLI4CP9(a%8=>z;Tl9<#^rr!!7oc*-7I zkKjv>*dzW(v>bS%fJgCB8~!I9v_;Ii{HZvzVtbFn%ixpJYEhIbZ2bibo2NOh(0fMq zHW+7i7c$r2^Jnaw4UPO)tRXuS;4|V&#MdD|LrWr?AyQy8BttGql;!>xjjZ%3U;eHl}gg zbZoR$h@ie7+C5em^d+~iGX#SGEct>hJFp?8Pi`gof??Onpe)~wJ_QL zyBHD~l*G!c!dDlU#mSrhf|Br;LQjS5xKmryRgyz?N}V7I{V)}VL_5}Os+>Na63|gOQXMuo<_-N44!_IMjPwDMK!aZ<#~-S zpetGivD~X{on2ryudKes%3R+BZ^p?g#jofIIa7_{&C)}*xxY#t| z!pY$NCgCChBjJis!nGx*Bui}xr#xjpv36R+LcEY5Ku{1tQZ2^CJ+41x*8M*uHYH=Y z^kPzkOWO(H+k4iK)+;AP;!A+OA)GoX=3Ko*fFOYsexgUAgc=Df^J zT&3p3Z?coSJb!WsDZ2;lEewt5Mq6xrepOb1Yun5?IOrKGwA zVW$*m*~zZS7`r~u*_QzoGinBz(1=5#xtWhyn#|Kl&8p#kldz>LL>)1}!3Mh;!sH zlO=!|r47vI03+JCgjRy9pN$yd#eysvyy$cb-6HTJE;hzxvq}aj+QR~2$TbTlU+{OJ zqA|>e9o5`(kx*jf6xRo+KFsB3_u6=al&7;T*4)FO*K?abKshEfkDEPJ&yxPk;kZA$ z{fitFBxJJP!3=_tghLG&sK4k4Up6~;+%s#8_1(k=L~Irx3l2i5jtl^X7E$c$I0)ND zi%C#XaekS*k`3Oji{Z_O#{oB%QXEer)1 znk3dNEl*SM1UKpHG_p8N5Bhsx=38{aUdq+~TBxGMt1wZO F{R77bvM>Mu literal 0 HcmV?d00001 diff --git a/algorithms/neat/genome/__pycache__/mutate.cpython-39.pyc b/algorithms/neat/genome/__pycache__/mutate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3b12be0fea951d3b34fd9600739bfd50a1d74f GIT binary patch literal 15985 zcmdU0Ym6M(RjyaRAMW*=l(u}Rihb`24o4U=+LO;6AC zV^me+m#Rs^IO%T0!tzo`FklCeVi6%BA|d`GA(S6T5r3#35fT!S!ykSk5W9Tm+^Vjs z?jC!0NkDX~PTxBBS+~wP_dDm_HcF+ugim$nviUcklcay9oA~D>ZeGIW)D%fFB~xxm zZMi1PlvY|wTdk>_R$E$415ayd?NlwzycE)zTBexY0uVX+jF(K_Iz!g`MK7icBNL4C2PUV-MpgCI?WXH%SKG#}pL*H0jT`r6s;2tIP0Ous zZP-S~?6&=axZ7;nj_aph+j#BL<*Mdq4b$A{bWO`&pj5rv=~#8Q+3jpxxv|kSulw_+ z)v{cR%Y?k;fHzSZ!z$LdtoqJI{YJfIIW){G1_n{r=uv`C0lzvf6_@iJ4AzsjrEPgz z>1m#}tqyZc^Pt+W44N7iLDRzmXl7UnN@sg1FNM+-rVF6j@GxjCg7(`t$um{YrnKzt~^uFZU1kSNcc#tNo+>wf_3%x~Z7z zt?XNpyWlRGviLSMQ~Q?urevnxQHI%x9@4;?m&`!F+*%IwX*%Io;lZOd>iy}jqs z&*ol-#YW%k+Pc}?+_Y?~9IkT_@@us9?w-55=Wgs+H=K)cezP-`4Kdqn zHWDRSs4k8=N1wah-MtofrQ@#Cbw$Ir`|%AT>x2GK1NHHkcWx zgYw|;{)xAwdornMT1*hB8Yv<@G*lM~=lF$4e;mKqv91XvLy79TUoh(KhG}gYdo9<` z(G99MEPs}h)^5wFTVT1Ce`ri`nOO?;T9hdmjYh0B+=Q)BGTvJ8O4D$nqI}o3jE+Am zlKi?wqm?+vaq+?%?v_8t+(esMW;w3uFHTwx9`0=1$}p#5b&M-5%f#reTDB8BYd#VY zzYr56f7Z5~U5J@A8hUBO;o(=J+t?#4k8?+lu`tR>JW6rIrpH-wnw>_=+CZ(nIt0zy z5O3QrT(g>uEq8;vo)2$g9WN$%qizoc{6weufQc^eN&oS}*>=Z0Yn|;t|DN17u8$I* z$ITbc8m&gxZn|4-hcnJLtWLLWofQK+vwOqOKHoChS4`uDS4iqBAd*~`^Kw>B;}T!x z-*2803LZSsMyTLto^N&QM$35t3DHb-{nCR!(1J{}vteAfoCj;)AUo2&dFWdXdRK<# z6`+F^=-pvx-Xe7H$Y33sw+bDMYu_BSZx-4&Ggu+*v*+<#_Qyb~S}hlv^q zpBCck=V7XvSJ@`}YRl^QneHyH1u$-nhHW)ORwmGNKP^NRuL0(9gACZN-Mvzq<=enO zFsDQ>EmnYU;{{g6^wMaFH-m-9PSYHb!LJtlG-_gLiM@h3%6*9lX@~tYL|!HG8j(*C z`7{xp?oUvP=G-A|!tkpFjK z!GC=5a`l8wromnyvPfi!NF=N}-Lj0cP_vA)w6mZ`qRg_(qRc|g^2}1qV#`v@^30+< zrf^9$6Y9eH$J)c{$jU^jQKu1vA8MQ0Vu<-|a?@l}xh3^Z1Yk((r#DlkYHGK#y+;7P zQk#mIGSj$wG`P!{S=^m8b7mg(9`jCkk9d!IC%wlSsop8?_`9UA_lg3~z@)Rnd~ce36&@09mAL8MLfuKKpplY7d$awj{Q zi&@M?is#~d??O+)tW~M>K9zmC_Y6IkcusE(Px7=U?GQL6G8wo9C^V4~E~p76!`h4bsY|EzXR7f>h%K_nK7U$AiA#wR@s&5xCXaw% z)Fc-R)Fd-lph7@(U;$deCNtwJS|WEeOA|vG<%|z!l<}Y;jTf32)~LXNLkp4Q_#8!` zG4Ui}aj~pVou_1xi6I5q2M%cDtR(s!_gj*gp+_>wjbV&L<~08xQ6S*~C341{jYQ@Y zZz_==RE|WNouNdLG4v671s8Qg7Z68pI(q0c7_HWg(2+G9~TezGoNN7PK;CoN4DfW&pcYL`n0Lus@`Wbl0TX0^!P1UI#Nxr!x1T{kVL79nu zb0|&-${}1pF7rU8-r)#q1j+Xij^I%cHVK43d>QuVP^Fp@$eNdT9Sngj7&Ame%(ofh zkx}_GlnF<+wKmkqCWz9%3pQ;(Rc~m{jm6?hNT>Z#Qr=s zvZIqDiSuqh&slSMgEs(J2-HsQY7#>z3AzVcy!nxZ= zTrzyM%gK3*`W2(&8H=Xv&}iDCg5Qk|@Wg?W7nQ$HLT{6>b13IN$?bg_r+pcunnx(q zxbCNEZ9>2^uJQG_O6GHGbN+f;;;}_Gvi9lzQPiXBl2$9LnW1@kvf+^q|ap4idx1u ztCr#3RODrKUY>{3qsjB~ntW1zUS3m{)meE_UYw*#nFAMTI8EArW;4RGTONOT2l^@O zZGR{iB?#Aa@*VtkV;RpA$%swfrP5=O1KJBnouHXMjKQj`DW8x71ti9qz64 zj*#!*EqOA$67{Q!snN!Xr+S)~^3q;rKhs-9{_-dv9t-kAj9EnfQRE-?6bjaI3rsgv zlo}yWlv;B&Z)Ma2~$#M?RDsps2_bDJ-jBC90?KJ$GW7Bd>ZAa-H zcgbrw<}Cs_AD_VmWY0YCbJIq9SEf>KW;;(MJ$btjpOHI7Z}qO+lVg0I3Gk1O@%h=n zUmxT1lYy@@UqSm5GltwX`0L4?_J#=pXg<^}t8hE=*SB6@cdk!XZ^ zbr>BzrZ8vnT=-?X5d^>~pgw}s`pcZ|wsnMpIiOBq8o)5D*I|z6;0!lIl>_MgvVQ&* z-Jne%9q?D3 zd+?V8Ahp0@RBD@)+61YV>;|Q`h%`a`5Q=TKx`w+U-kCoah*h&QVrBL^O>CqRhTSrD zEqj|-eOigU`A(m5%_Q~ij z+`yuAhD1{&ARkEBE*`8_2;VYslD>nzivss?ITS{go?8VHOuw%wg06CUMP5;l%PT;S zk1NZv_^tsxa@}f{@t~jWwoICJ$JfYQWAdvBi<~SVVG%jV_tFT|BEU>x##jlo#V`sr7)F)65@3pId=Che z_26{@Q&`Jyf{}8FXknRhnNWXHTAmqqoSHA}jx;rIuxz${|Qp!H-SD z@lycD4J)RVWIW5X zOvlhhvSy@^KLH7!)|FUJNyd3~t>(uCJ%d4;s5wp2W5S(_q{;pbYK<-YNSbSf@OfY- zSc>gmr&41pQ4hS&lK|67Ae9#(z$Cu&5M!J2szEm#^%!lP`o&-iWrO@AABSD|SP#In zIJ4pxLJmYFb_0;bOAswN;fFdVfoEa=77`qq!+|C2l4-p7kAvTAIHv+;~ZeS=FXX%?Z++7aXvA=4>yk#MB!O<~# z*>r{~5pO4RLmuvA207n}qA^IA;D;}OUeMHx@Lf&K#A}<*Z-mFz66et9eZE@N?3eJ% z=L~$6Uc3q$afzk;a#ayW4?cww)wI7*-|9AT-YXhCJ*xdHRDYic#c1t35jt`ZTJoe6 z+P_KpzeVJ?iTnF^%{VYxw2i2~KPNb`s+X!+e+=qmGe*gpila*+jopZtvm+jl2uhv0E9|b8A6W%B!VCvGbLY1a; zU&9I~A=-@HRhr&l>q{6}rIrg99-JklWHi0bJxv^%`X<^FXQ+OkQY^yHA?4?ty?Whp zoXx$~-;xkBpaqb)Sq;35YKqYvJ|=4mH$ zkD>C=q77^w0|vNA<6(Od&0mUT?eA$Y^a8Pc{m4Ux!WBe_hPsBAw2wuApb7qL;$bg~13RC<(B)dcZ*%ClzW>zlZ4Dcdfd3jZ4y994x z)SaJ?v`q-7iWb5Tjc|mX-^I-Y%7LD{>NYNj7U5<5TZlHCiH3ND9K3gchuk!rx)cNv zVvCqbIx5x6AWyWBAwef$g@^)o^Tsg=E2nVSHHUvf(NRc5Q#Lh-Yv$&650KU9xu!{M zl=xo^CYc^RlU#)jrHk-A`)Qw5d*HJEw zWl)4AauxX$)(ISm(q2j#<7-)QS-IQdeSoH6>}{@e(+B8H`~5R3id5t@?UFk?$N!)C{^N-A=j#rh zAAa>f{{52Ud`*q%r;5fWc(YyRIMR_Ua$?qIjOuRW#$MYqbuaSUKKlH~kArq4uC}YD zKR~bEPN!H||>x-{8lQj%x6$cLFBRk&dvu-6QuFxff!H=Zw9;fqn29g!fK-Vw~j(G4o7 z)^zpUq~knwR3NrwP}E|T?e6GM76CYP_ll}UhZv}0F4rQ-25F{>bAY2OfGu7u^a(E0 zd5mm(4toi^D(j3+s{2K?>@7SH_|nl!VEnI-*Wx5!Q)`J5`OUrx>n@jK?o5g39O!0rR;Rwmu6qr);Vv+J9qoc`1E~uHUfAsj zy^{zWct%-0Od~I2l=T)WP(A!!yV9ky=Po-33~I>6?5%5**gYgw^$)!RC@>T-cwt2h z@_?nWUp`F)Aj*S z)}9|GQ-Ebg1h!bFrB`N#kHz@|3}s}*PHmAbG5tPjO|S5TzD>n?QL)>a!IAv<1jnz? zD2<-MT*>TV^gt1a^$9)T_ZmL+kR5pB1b{Kj9!|m-zYNAsKajz))5DrR9KnPJ&8mV< zZR{U185p{73Wgpeg<2rJ*QmM9LFE85`*Y*UxVq=sxo{=K8RHiaRt{=J&5XnK5Uo?= zb7OB@8_)H~CsxEid?*-sntO?e_Fo)Zd#-e5J&ZBqVFF)c%pbl%kc-Ut(|v9cgSC>K zhni=p&SBOkmPM;ddubUiVKPNI7TKP46P7-;SL@R%1mSG*-`)F*4kO8T#WQ=P3^r0O7Ub5}S7-UGe-}Ov6pkiwqW^)>?UGTc z^gO4CLnqlNfi(M|Q`{`^>y(IX^83~S0bP;v=jK2JvGJ<2AcG#TGP0VS&C;A8-XVB5 z`91zLO(P`|%if9yp2P}~!#9RDu$0ceo5nGus3H}sB!}2l{#+pTeWTyHNpra0j6@8vl@5&dEcnM9v1Zj6!^)#a zw_AMuNsB+)Kp3X)-4_4w$&(Wmefp01KWXuw()^#b`0ZP?U;?Ja>{uCE5L9V82B+2( zW+0uZyps-;h9J|VBgmGsFk2BV{`8$B`OPh>11pbifjbV+;zy8b>^X--RBq1*GbkYmAcRMartAYQmVUAP&h%|gT>6ZwegVX1mz80DLWPTAnk zL`0-|JV}quk7v-pthaycPRL&xe`1MEoFLLfvjr1VCYfjHWZp?+^*b_am63_x@7&*h zXjm(2!8Rq?+JffK8su5(Pty!~9fRD8;9Gpdn)HStdgg>%cE$;7BrhhVsUTI0#i$6L z?<%d3d-w!;Q7z#tyikr#E6}+q{(pwhGlXl$nXfZ_7vFCv2<0~>=evVSZZ9XuQz$DL zvzL^uP`-yO5V?bkd;~cQDa{h@m`jdPoZ!VT%+azi4kMxIdNH*HxD$lf6OD=l?hB_B zsx(eY=CrX;-|W?z>22!$5ek5^PW25G?S@v*DB+NuGa#s2{={krlPvhqQ83(~} zr;urU8V4&%gMep_TvXyPfz9A6M))?XRxiA47#Ef9Fs84E3jTpat5Y)k3F<|y13wogolY|_ zf-{-<5dnNeT6aX=NB@#)lsg|u6wU7ybG>2G&EYOGO}8{`V1EGX*R)RmoC;cOrEe&Z zgy4q5q?h88I738RmQZuO6U z+WzhD0nnD=g7faEf?;+xh{9=9=-Q^ub)oiAoXTOOwn$?$V%1;`x5h4_ocCN-ztUj! bhRfdnz-5b<@Ye9w>ueq~F1qw$KI;Dgks8qo literal 0 HcmV?d00001 diff --git a/algorithms/neat/genome/activations.py b/algorithms/neat/genome/activations.py new file mode 100644 index 0000000..16f0d52 --- /dev/null +++ b/algorithms/neat/genome/activations.py @@ -0,0 +1,138 @@ +import jax +import jax.numpy as jnp +from jax import jit + + +@jit +def sigmoid_act(z): + z = jnp.clip(z * 5, -60, 60) + return 1 / (1 + jnp.exp(-z)) + + +@jit +def tanh_act(z): + z = jnp.clip(z * 2.5, -60, 60) + return jnp.tanh(z) + + +@jit +def sin_act(z): + z = jnp.clip(z * 5, -60, 60) + return jnp.sin(z) + + +@jit +def gauss_act(z): + z = jnp.clip(z, -3.4, 3.4) + return jnp.exp(-5 * z ** 2) + + +@jit +def relu_act(z): + return jnp.maximum(z, 0) + + +@jit +def elu_act(z): + return jnp.where(z > 0, z, jnp.exp(z) - 1) + + +@jit +def lelu_act(z): + leaky = 0.005 + return jnp.where(z > 0, z, leaky * z) + + +@jit +def selu_act(z): + lam = 1.0507009873554804934193349852946 + alpha = 1.6732632423543772848170429916717 + return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1)) + + +@jit +def softplus_act(z): + z = jnp.clip(z * 5, -60, 60) + return 0.2 * jnp.log(1 + jnp.exp(z)) + + +@jit +def identity_act(z): + return z + + +@jit +def clamped_act(z): + return jnp.clip(z, -1, 1) + + +@jit +def inv_act(z): + return 1 / z + + +@jit +def log_act(z): + z = jnp.maximum(z, 1e-7) + return jnp.log(z) + + +@jit +def exp_act(z): + z = jnp.clip(z, -60, 60) + return jnp.exp(z) + + +@jit +def abs_act(z): + return jnp.abs(z) + + +@jit +def hat_act(z): + return jnp.maximum(0, 1 - jnp.abs(z)) + + +@jit +def square_act(z): + return z ** 2 + + +@jit +def cube_act(z): + return z ** 3 + + +ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act, + identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act] + +act_name2key = { + 'sigmoid': 0, + 'tanh': 1, + 'sin': 2, + 'gauss': 3, + 'relu': 4, + 'elu': 5, + 'lelu': 6, + 'selu': 7, + 'softplus': 8, + 'identity': 9, + 'clamped': 10, + 'inv': 11, + 'log': 12, + 'exp': 13, + 'abs': 14, + 'hat': 15, + 'square': 16, + 'cube': 17, +} + + +@jit +def act(idx, z): + idx = jnp.asarray(idx, dtype=jnp.int32) + # change idx from float to int + return jax.lax.switch(idx, ACT_TOTAL_LIST, z) + + +vectorized_act = jax.vmap(act, in_axes=(0, 0)) diff --git a/algorithms/neat/genome/aggregations.py b/algorithms/neat/genome/aggregations.py new file mode 100644 index 0000000..6cf172e --- /dev/null +++ b/algorithms/neat/genome/aggregations.py @@ -0,0 +1,109 @@ +""" +aggregations, two special case need to consider: +1. extra 0s +2. full of 0s +""" + +import jax +import jax.numpy as jnp +import numpy as np +from jax import jit + + +@jit +def sum_agg(z): + z = jnp.where(jnp.isnan(z), 0, z) + return jnp.sum(z, axis=0) + + +@jit +def product_agg(z): + z = jnp.where(jnp.isnan(z), 1, z) + return jnp.prod(z, axis=0) + + +@jit +def max_agg(z): + z = jnp.where(jnp.isnan(z), -jnp.inf, z) + return jnp.max(z, axis=0) + + +@jit +def min_agg(z): + z = jnp.where(jnp.isnan(z), jnp.inf, z) + return jnp.min(z, axis=0) + + +@jit +def maxabs_agg(z): + z = jnp.where(jnp.isnan(z), 0, z) + abs_z = jnp.abs(z) + max_abs_index = jnp.argmax(abs_z) + return z[max_abs_index] + + +@jit +def median_agg(z): + + non_zero_mask = ~jnp.isnan(z) + n = jnp.sum(non_zero_mask, axis=0) + + z = jnp.where(jnp.isnan(z), jnp.inf, z) + sorted_valid_values = jnp.sort(z) + + def _even_case(): + return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2 + + def _odd_case(): + return sorted_valid_values[n // 2] + + median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case) + + return median + + +@jit +def mean_agg(z): + non_zero_mask = ~jnp.isnan(z) + valid_values_sum = sum_agg(z) + valid_values_count = jnp.sum(non_zero_mask, axis=0) + mean_without_zeros = valid_values_sum / valid_values_count + return mean_without_zeros + + +AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg] + +agg_name2key = { + 'sum': 0, + 'product': 1, + 'max': 2, + 'min': 3, + 'maxabs': 4, + 'median': 5, + 'mean': 6, +} + + +@jit +def agg(idx, z): + idx = jnp.asarray(idx, dtype=jnp.int32) + + def full_zero(): + return 0. + + def not_full_zero(): + return jax.lax.switch(idx, AGG_TOTAL_LIST, z) + + return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero) + + +vectorized_agg = jax.vmap(agg, in_axes=(0, 0)) + +if __name__ == '__main__': + array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32) + for names in agg_name2key.keys(): + print(names, agg(agg_name2key[names], array)) + + array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32) + for names in agg_name2key.keys(): + print(names, agg(agg_name2key[names], array2)) diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py new file mode 100644 index 0000000..5e130d9 --- /dev/null +++ b/algorithms/neat/genome/crossover.py @@ -0,0 +1,151 @@ +from functools import partial +from typing import Tuple + +import jax +from jax import jit, vmap, Array +from jax import numpy as jnp + +# from .utils import flatten_connections, unflatten_connections +from algorithms.neat.genome.utils import flatten_connections, unflatten_connections + + +@vmap +def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array, + batch_connections2: Array) -> Tuple[Array, Array]: + """ + crossover a batch of genomes + :param randkeys: batches of random keys + :param batch_nodes1: + :param batch_connections1: + :param batch_nodes2: + :param batch_connections2: + :return: + """ + return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2) + + +@jit +def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ + -> Tuple[Array, Array]: + """ + use genome1 and genome2 to generate a new genome + notice that genome1 should have higher fitness than genome2 (genome1 is winner!) + :param randkey: + :param nodes1: + :param connections1: + :param nodes2: + :param connections2: + :return: + """ + randkey_1, randkey_2 = jax.random.split(randkey) + + # crossover nodes + keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + nodes2 = align_array(keys1, keys2, nodes2, 'node') + new_nodes = jnp.where(jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) + + # crossover connections + cons1 = flatten_connections(keys1, connections1) + cons2 = flatten_connections(keys2, connections2) + con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] + cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') + new_cons = jnp.where(jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) + new_cons = unflatten_connections(len(keys1), new_cons) + + return new_nodes, new_cons + + +@partial(jit, static_argnames=['gene_type']) +def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: + """ + make ar2 align with ar1. + :param seq1: + :param seq2: + :param ar2: + :param gene_type: + :return: + align means to intersect part of ar2 will be at the same position as ar1, + non-intersect part of ar2 will be set to Nan + """ + seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :] + mask = (seq1 == seq2) & (~jnp.isnan(seq1)) + + if gene_type == 'connection': + mask = jnp.all(mask, axis=2) + + intersect_mask = mask.any(axis=1) + idx = jnp.arange(0, len(seq1)) + idx_fixed = jnp.dot(mask, idx) + + refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan) + + return refactor_ar2 + + +@jit +def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: + """ + crossover two genes + :param rand_key: + :param g1: + :param g2: + :return: + only gene with the same key will be crossover, thus don't need to consider change key + """ + r = jax.random.uniform(rand_key, shape=g1.shape) + return jnp.where(r > 0.5, g1, g2) + + +if __name__ == '__main__': + import numpy as np + + randkey = jax.random.PRNGKey(40) + nodes1 = np.array([ + [4, 1, 1, 1, 1], + [6, 2, 2, 2, 2], + [1, 3, 3, 3, 3], + [5, 4, 4, 4, 4], + [np.nan, np.nan, np.nan, np.nan, np.nan] + ]) + nodes2 = np.array([ + [4, 1.5, 1.5, 1.5, 1.5], + [7, 3.5, 3.5, 3.5, 3.5], + [5, 4.5, 4.5, 4.5, 4.5], + [np.nan, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan], + ]) + weights1 = np.array([ + [ + [1, 2, 3, 4., np.nan], + [5, np.nan, 7, 8, np.nan], + [9, 10, 11, 12, np.nan], + [13, 14, 15, 16, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan] + ], + [ + [0, 1, 0, 1, np.nan], + [0, np.nan, 0, 1, np.nan], + [0, 1, 0, 1, np.nan], + [0, 1, 0, 1, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan] + ] + ]) + weights2 = np.array([ + [ + [1.5, 2.5, 3.5, np.nan, np.nan], + [3.5, 4.5, 5.5, np.nan, np.nan], + [6.5, 7.5, 8.5, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan] + ], + [ + [1, 0, 1, np.nan, np.nan], + [1, 0, 1, np.nan, np.nan], + [1, 0, 1, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan] + ] + ]) + + res = crossover(randkey, nodes1, weights1, nodes2, weights2) + print(*res, sep='\n') diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py new file mode 100644 index 0000000..fd65a71 --- /dev/null +++ b/algorithms/neat/genome/distance.py @@ -0,0 +1,71 @@ +from functools import partial + +from jax import jit, vmap, Array +from jax import numpy as jnp + +from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis + + +@jit +def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array: + """ + Calculate the distance between two genomes. + nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg] + connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] + """ + + node_distance = gene_distance(nodes1, nodes2, 'node') + + # refactor connections + keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + cons1 = flatten_connections(keys1, connections1) + cons2 = flatten_connections(keys2, connections2) + + connection_distance = gene_distance(cons1, cons2, 'connection') + return node_distance + connection_distance + + +@partial(jit, static_argnames=["gene_type"]) +def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): + if gene_type == 'node': + keys1, keys2 = ar1[:, :1], ar2[:, :1] + else: # connection + keys1, keys2 = ar1[:, :2], ar2[:, :2] + + n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) + nodes = jnp.concatenate((ar1, ar2), axis=0) + sorted_nodes = nodes[n_sorted_indices] + fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] + + non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask) + if gene_type == 'node': + node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) + else: # connection + node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes) + + node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance) + homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1]) + + gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1)) + gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1)) + + val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe + return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2) + + +@partial(vmap, in_axes=(0, 0)) +def homologous_node_distance(n1, n2): + d = 0 + d += jnp.abs(n1[1] - n2[1]) # bias + d += jnp.abs(n1[2] - n2[2]) # response + d += n1[3] != n2[3] # activation + d += n1[4] != n2[4] + return d + + +@partial(vmap, in_axes=(0, 0)) +def homologous_connection_distance(c1, c2): + d = 0 + d += jnp.abs(c1[2] - c2[2]) # weight + d += c1[3] != c2[3] # enable + return d diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py new file mode 100644 index 0000000..86e9116 --- /dev/null +++ b/algorithms/neat/genome/forward.py @@ -0,0 +1,171 @@ +from functools import partial + +import jax +from jax import Array, numpy as jnp +from jax import jit, vmap +from numpy.typing import NDArray + +from .aggregations import agg +from .activations import act +from .graph import topological_sort, batch_topological_sort, topological_sort_debug +from .utils import I_INT + + +def create_forward_function(nodes: NDArray, connections: NDArray, + N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False): + """ + create forward function for different situations + + :param nodes: shape (N, 5) or (pop_size, N, 5) + :param connections: shape (2, N, N) or (pop_size, 2, N, N) + :param N: + :param input_idx: + :param output_idx: + :param batch: using batch or not + :param debug: debug mode + :return: + """ + + if debug: + cal_seqs = topological_sort(nodes, connections) + return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx, + cal_seqs, nodes, connections) + + if nodes.ndim == 2: # single genome + cal_seqs = topological_sort(nodes, connections) + if not batch: + return lambda inputs: forward_single(inputs, N, input_idx, output_idx, + cal_seqs, nodes, connections) + else: + return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx, + cal_seqs, nodes, connections) + elif nodes.ndim == 3: # pop genome + pop_cal_seqs = batch_topological_sort(nodes, connections) + if not batch: + return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx, + pop_cal_seqs, nodes, connections) + else: + return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx, + pop_cal_seqs, nodes, connections) + else: + raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}") + + +# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx']) +@partial(jit, static_argnames=['N']) +def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, + cal_seqs: Array, nodes: Array, connections: Array) -> Array: + """ + jax forward for single input shaped (input_num, ) + nodes, connections are single genome + + :argument inputs: (input_num, ) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument cal_seqs: (N, ) + :argument nodes: (N, 5) + :argument connections: (2, N, N) + + :return (output_num, ) + """ + ini_vals = jnp.full((N,), jnp.nan) + ini_vals = ini_vals.at[input_idx].set(inputs) + + def scan_body(carry, i): + def hit(): + ins = carry * connections[0, :, i] + z = agg(nodes[i, 4], ins) + z = z * nodes[i, 2] + nodes[i, 1] + z = act(nodes[i, 3], z) + + # for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals + new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z)) + return new_vals + + def miss(): + return carry + + return jax.lax.cond(i == I_INT, miss, hit), None + + vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs) + + return vals[output_idx] + + +def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections): + ini_vals = jnp.full((N,), jnp.nan) + ini_vals = ini_vals.at[input_idx].set(inputs) + vals = ini_vals + for i in cal_seqs: + if i == I_INT: + break + ins = vals * connections[0, :, i] + z = agg(nodes[i, 4], ins) + z = z * nodes[i, 2] + nodes[i, 1] + z = act(nodes[i, 3], z) + + # for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals + vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z)) + + return vals[output_idx] + + +@partial(vmap, in_axes=(0, None, None, None, None, None, None)) +def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, + cal_seqs: Array, nodes: Array, connections: Array) -> Array: + """ + jax forward for batch_inputs shaped (batch_size, input_num) + nodes, connections are single genome + + :argument batch_inputs: (batch_size, input_num) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument cal_seqs: (N, ) + :argument nodes: (N, 5) + :argument connections: (2, N, N) + + :return (batch_size, output_num) + """ + return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) + + +@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) +def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, + pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: + """ + jax forward for single input shaped (input_num, ) + pop_nodes, pop_connections are population of genomes + + :argument inputs: (input_num, ) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument pop_cal_seqs: (pop_size, N) + :argument pop_nodes: (pop_size, N, 5) + :argument pop_connections: (pop_size, 2, N, N) + + :return (pop_size, output_num) + """ + return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) + + +@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) +def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, + pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: + """ + jax forward for batch input shaped (batch, input_num) + pop_nodes, pop_connections are population of genomes + + :argument batch_inputs: (batch_size, input_num) + :argument N: int + :argument input_idx: (input_num, ) + :argument output_idx: (output_num, ) + :argument pop_cal_seqs: (pop_size, N) + :argument pop_nodes: (pop_size, N, 5) + :argument pop_connections: (pop_size, 2, N, N) + + :return (pop_size, batch_size, output_num) + """ + return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py new file mode 100644 index 0000000..0173b06 --- /dev/null +++ b/algorithms/neat/genome/genome.py @@ -0,0 +1,195 @@ +""" +Vectorization of genome representation. + +Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where: + +1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes +too large to be represented by the current value of N. +2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function +(act), and aggregation function (agg). +3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled +status. +Empty nodes or connections are represented using np.nan. + +""" +from typing import Tuple +from functools import partial + +import numpy as np +from numpy.typing import NDArray +from jax import numpy as jnp +from jax import jit +from jax import Array + +from algorithms.neat.genome.utils import fetch_first, fetch_last + +EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) + + +def create_initialize_function(config): + pop_size = config.neat.population.pop_size + N = config.basic.init_maximum_nodes + num_inputs = config.basic.num_inputs + num_outputs = config.basic.num_outputs + default_bias = config.neat.gene.bias.init_mean + default_response = config.neat.gene.response.init_mean + # default_act = config.neat.gene.activation.default + # default_agg = config.neat.gene.aggregation.default + default_act = 0 + default_agg = 0 + default_weight = config.neat.gene.weight.init_mean + return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response, + default_act, default_agg, default_weight) + + +def initialize_genomes(pop_size: int, + N: int, + num_inputs: int, num_outputs: int, + default_bias: float = 0.0, + default_response: float = 1.0, + default_act: int = 0, + default_agg: int = 0, + default_weight: float = 1.0) \ + -> Tuple[NDArray, NDArray, NDArray, NDArray]: + """ + Initialize genomes with default values. + + Args: + pop_size (int): Number of genomes to initialize. + N (int): Maximum number of nodes in the network. + num_inputs (int): Number of input nodes. + num_outputs (int): Number of output nodes. + default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0. + default_response (float, optional): Default response value for output nodes. Defaults to 1.0. + default_act (int, optional): Default activation function index for output nodes. Defaults to 1. + default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0. + default_weight (float, optional): Default weight value for connections. Defaults to 0.0. + + Raises: + AssertionError: If the sum of num_inputs, num_outputs, and 1 is greater than N. + + Returns: + Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays. + """ + # Reserve one row for potential mutation adding an extra node + assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \ + f"{num_inputs} and output_size: {num_outputs}!" + + pop_nodes = np.full((pop_size, N, 5), np.nan) + pop_connections = np.full((pop_size, 2, N, N), np.nan) + input_idx = np.arange(num_inputs) + output_idx = np.arange(num_inputs, num_inputs + num_outputs) + + pop_nodes[:, input_idx, 0] = input_idx + pop_nodes[:, output_idx, 0] = output_idx + + pop_nodes[:, output_idx, 1] = default_bias + pop_nodes[:, output_idx, 2] = default_response + pop_nodes[:, output_idx, 3] = default_act + pop_nodes[:, output_idx, 4] = default_agg + + for i in input_idx: + for j in output_idx: + pop_connections[:, 0, i, j] = default_weight + pop_connections[:, 1, i, j] = 1 + + return pop_nodes, pop_connections, input_idx, output_idx + + +def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: + """ + Expand the genome to accommodate more nodes. + :param pop_nodes: + :param pop_connections: + :param new_N: + :return: + """ + pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1] + + new_pop_nodes = np.full((pop_size, new_N, 5), np.nan) + new_pop_nodes[:, :old_N, :] = pop_nodes + + new_pop_connections = np.full((pop_size, 2, new_N, new_N), np.nan) + new_pop_connections[:, :, :old_N, :old_N] = pop_connections + return new_pop_nodes, new_pop_connections + + +@jit +def add_node(new_node_key: int, nodes: Array, connections: Array, + bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: + """ + add a new node to the genome. + """ + exist_keys = nodes[:, 0] + idx = fetch_first(jnp.isnan(exist_keys)) + nodes = nodes.at[idx].set(jnp.array([new_node_key, bias, response, act, agg])) + return nodes, connections + + +@jit +def delete_node(node_key: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: + """ + delete a node from the genome. only delete the node, regardless of connections. + """ + node_keys = nodes[:, 0] + idx = fetch_first(node_keys == node_key) + return delete_node_by_idx(idx, nodes, connections) + + +@jit +def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: + """ + delete a node from the genome. only delete the node, regardless of connections. + """ + node_keys = nodes[:, 0] + # move the last node to the deleted node's position + last_real_idx = fetch_last(~jnp.isnan(node_keys)) + nodes = nodes.at[idx].set(nodes[last_real_idx]) + nodes = nodes.at[last_real_idx].set(EMPTY_NODE) + return nodes, connections + + +@jit +def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array, + weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]: + """ + add a new connection to the genome. + """ + node_keys = nodes[:, 0] + from_idx = fetch_first(node_keys == from_node) + to_idx = fetch_first(node_keys == to_node) + return add_connection_by_idx(from_idx, to_idx, nodes, connections, weight, enabled) + + +@jit +def add_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array, + weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]: + """ + add a new connection to the genome. + """ + connections = connections.at[:, from_idx, to_idx].set(jnp.array([weight, enabled])) + return nodes, connections + + +@jit +def delete_connection(from_node: int, to_node: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: + """ + delete a connection from the genome. + """ + node_keys = nodes[:, 0] + from_idx = fetch_first(node_keys == from_node) + to_idx = fetch_first(node_keys == to_node) + return delete_connection_by_idx(from_idx, to_idx, nodes, connections) + + +@jit +def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]: + """ + delete a connection from the genome. + """ + connections = connections.at[:, from_idx, to_idx].set(np.nan) + return nodes, connections + +# if __name__ == '__main__': +# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1) +# print(pop_nodes, pop_connections) diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py new file mode 100644 index 0000000..55f8c4d --- /dev/null +++ b/algorithms/neat/genome/graph.py @@ -0,0 +1,198 @@ +""" +Some graph algorithms implemented in jax. +Only used in feed-forward networks. +""" + +import jax +from jax import jit, vmap, Array +from jax import numpy as jnp + +# from .utils import fetch_first, I_INT +from algorithms.neat.genome.utils import fetch_first, I_INT + + +@jit +def topological_sort(nodes: Array, connections: Array) -> Array: + """ + a jit-able version of topological_sort! that's crazy! + :param nodes: nodes array + :param connections: connections array + :return: topological sorted sequence + + Example: + nodes = jnp.array([ + [0], + [1], + [2], + [3] + ]) + connections = jnp.array([ + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ], + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ] + ]) + + topological_sort(nodes, connections) -> [0, 1, 2, 3] + """ + connections_enable = connections[1, :, :] == 1 + in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0)) + res = jnp.full(in_degree.shape, I_INT) + idx = 0 + + def scan_body(carry, _): + res_, idx_, in_degree_ = carry + i = fetch_first(in_degree_ == 0.) + + def hit(): + # add to res and flag it is already in it + new_res = res_.at[idx_].set(i) + new_idx = idx_ + 1 + new_in_degree = in_degree_.at[i].set(-1) + + # decrease in_degree of all its children + children = connections_enable[i, :] + new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree) + return new_res, new_idx, new_in_degree + + def miss(): + return res_, idx_, in_degree_ + + return jax.lax.cond(i == I_INT, miss, hit), None + + scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0]) + res, _, _ = scan_res + + return res + + +# @jit +def topological_sort_debug(nodes: Array, connections: Array) -> Array: + connections_enable = connections[1, :, :] == 1 + in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0)) + res = jnp.full(in_degree.shape, I_INT) + idx = 0 + + for _ in range(in_degree.shape[0]): + i = fetch_first(in_degree == 0.) + if i == I_INT: + break + res = res.at[idx].set(i) + idx += 1 + in_degree = in_degree.at[i].set(-1) + children = connections_enable[i, :] + in_degree = jnp.where(children, in_degree - 1, in_degree) + + return res + + +@vmap +def batch_topological_sort(nodes: Array, connections: Array) -> Array: + """ + batch version of topological_sort + :param nodes: + :param connections: + :return: + """ + return topological_sort(nodes, connections) + + +@jit +def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array: + """ + Check whether a new connection (from_idx -> to_idx) will cause a cycle. + + :param nodes: JAX array + The array of nodes. + :param connections: JAX array + The array of connections. + :param from_idx: int + The index of the starting node. + :param to_idx: int + The index of the ending node. + :return: JAX array + An array indicating if there is a cycle caused by the new connection. + + Example: + nodes = jnp.array([ + [0], + [1], + [2], + [3] + ]) + connections = jnp.array([ + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ], + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ] + ]) + + check_cycles(nodes, connections, 3, 2) -> True + check_cycles(nodes, connections, 2, 3) -> False + check_cycles(nodes, connections, 0, 3) -> False + check_cycles(nodes, connections, 1, 0) -> False + """ + connections_enable = connections[1, :, :] == 1 + connections_enable = connections_enable.at[from_idx, to_idx].set(True) + nodes_visited = jnp.full(nodes.shape[0], False) + nodes_visited = nodes_visited.at[to_idx].set(True) + + def scan_body(visited, _): + new_visited = jnp.dot(visited, connections_enable) + new_visited = jnp.logical_or(visited, new_visited) + return new_visited, None + + nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0]) + + return nodes_visited[from_idx] + + +if __name__ == '__main__': + nodes = jnp.array([ + [0], + [1], + [2], + [3], + [jnp.nan] + ]) + connections = jnp.array([ + [ + [0, 0, 1, 0, jnp.nan], + [0, 0, 1, 1, jnp.nan], + [0, 0, 0, 1, jnp.nan], + [0, 0, 0, 0, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] + ], + [ + [0, 0, 1, 0, jnp.nan], + [0, 0, 1, 1, jnp.nan], + [0, 0, 0, 1, jnp.nan], + [0, 0, 0, 0, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] + ] + ] + ) + + print(topological_sort_debug(nodes, connections)) + print(topological_sort(nodes, connections)) + + print(check_cycles(nodes, connections, 3, 2)) + print(check_cycles(nodes, connections, 2, 3)) + print(check_cycles(nodes, connections, 0, 3)) + print(check_cycles(nodes, connections, 1, 0)) diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py new file mode 100644 index 0000000..3334a2d --- /dev/null +++ b/algorithms/neat/genome/mutate.py @@ -0,0 +1,538 @@ +from typing import Tuple +from functools import partial + +import jax +from jax import numpy as jnp +from jax import jit, vmap, Array + +from .utils import fetch_random, fetch_first, I_INT +from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx +from .graph import check_cycles + + +def create_mutate_function(config, input_keys, output_keys, batch: bool): + """ + create mutate function for different situations + :param output_keys: + :param input_keys: + :param config: + :param batch: mutate for population or not + :return: + """ + bias = config.neat.gene.bias + bias_default = bias.init_mean + bias_mean = bias.init_mean + bias_std = bias.init_stdev + bias_mutate_strength = bias.mutate_power + bias_mutate_rate = bias.mutate_rate + bias_replace_rate = bias.replace_rate + + response = config.neat.gene.response + response_default = response.init_mean + response_mean = response.init_mean + response_std = response.init_stdev + response_mutate_strength = response.mutate_power + response_mutate_rate = response.mutate_rate + response_replace_rate = response.replace_rate + + weight = config.neat.gene.weight + weight_mean = weight.init_mean + weight_std = weight.init_stdev + weight_mutate_strength = weight.mutate_power + weight_mutate_rate = weight.mutate_rate + weight_replace_rate = weight.replace_rate + + activation = config.neat.gene.activation + # act_default = activation.default + act_default = 0 + act_range = len(activation.options) + act_replace_rate = activation.mutate_rate + + aggregation = config.neat.gene.aggregation + # agg_default = aggregation.default + agg_default = 0 + agg_range = len(aggregation.options) + agg_replace_rate = aggregation.mutate_rate + + enabled = config.neat.gene.enabled + enabled_reverse_rate = enabled.mutate_rate + + genome = config.neat.genome + add_node_rate = genome.node_add_prob + delete_node_rate = genome.node_delete_prob + add_connection_rate = genome.conn_add_prob + delete_connection_rate = genome.conn_delete_prob + single_structure_mutate = genome.single_structural_mutation + + if not batch: + return lambda rand_key, nodes, connections, new_node_key: \ + mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys, + bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, + bias_replace_rate, response_default, response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate, + weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, + weight_replace_rate, act_default, act_range, act_replace_rate, + agg_default, agg_range, agg_replace_rate, enabled_reverse_rate, + add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate, + single_structure_mutate) + else: + batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, *(None,) * 31)) + return lambda rand_keys, pop_nodes, pop_connections, new_node_keys: \ + batched_mutate(rand_keys, pop_nodes, pop_connections, new_node_keys, input_keys, output_keys, + bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, + bias_replace_rate, response_default, response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate, + weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, + weight_replace_rate, act_default, act_range, act_replace_rate, + agg_default, agg_range, agg_replace_rate, enabled_reverse_rate, + add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate, + single_structure_mutate) + + +@partial(jit, static_argnames=["single_structure_mutate"]) +def mutate(rand_key: Array, + nodes: Array, + connections: Array, + new_node_key: int, + input_keys: Array, + output_keys: Array, + bias_default: float = 0, + bias_mean: float = 0, + bias_std: float = 1, + bias_mutate_strength: float = 0.5, + bias_mutate_rate: float = 0.7, + bias_replace_rate: float = 0.1, + response_default: float = 1, + response_mean: float = 1., + response_std: float = 0., + response_mutate_strength: float = 0., + response_mutate_rate: float = 0., + response_replace_rate: float = 0., + weight_mean: float = 0., + weight_std: float = 1., + weight_mutate_strength: float = 0.5, + weight_mutate_rate: float = 0.7, + weight_replace_rate: float = 0.1, + act_default: int = 0, + act_range: int = 5, + act_replace_rate: float = 0.1, + agg_default: int = 0, + agg_range: int = 5, + agg_replace_rate: float = 0.1, + enabled_reverse_rate: float = 0.1, + add_node_rate: float = 0.2, + delete_node_rate: float = 0.2, + add_connection_rate: float = 0.4, + delete_connection_rate: float = 0.4, + single_structure_mutate: bool = True): + """ + :param output_keys: + :param input_keys: + :param agg_default: + :param act_default: + :param response_default: + :param bias_default: + :param rand_key: + :param nodes: (N, 5) + :param connections: (2, N, N) + :param new_node_key: + :param bias_mean: + :param bias_std: + :param bias_mutate_strength: + :param bias_mutate_rate: + :param bias_replace_rate: + :param response_mean: + :param response_std: + :param response_mutate_strength: + :param response_mutate_rate: + :param response_replace_rate: + :param weight_mean: + :param weight_std: + :param weight_mutate_strength: + :param weight_mutate_rate: + :param weight_replace_rate: + :param act_range: + :param act_replace_rate: + :param agg_range: + :param agg_replace_rate: + :param enabled_reverse_rate: + :param add_node_rate: + :param delete_node_rate: + :param add_connection_rate: + :param delete_connection_rate: + :param single_structure_mutate: a genome is structurally mutate at most once + :return: + """ + + # mutate_structure + def nothing(rk, n, c): + return n, c + + def m_add_node(rk, n, c): + return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default) + + def m_delete_node(rk, n, c): + return mutate_delete_node(rk, n, c, input_keys, output_keys) + + def m_add_connection(rk, n, c): + return mutate_add_connection(rk, n, c, input_keys, output_keys) + + def m_delete_connection(rk, n, c): + return mutate_delete_connection(rk, n, c) + + mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection] + + if single_structure_mutate: + r1, r2, rand_key = jax.random.split(rand_key, 3) + d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate) + + # shorten variable names for beauty + anr, dnr = add_node_rate / d, delete_node_rate / d + acr, dcr = add_connection_rate / d, delete_connection_rate / d + + r = rand(r1) + branch = 0 + branch = jnp.where(r <= anr, 1, branch) + branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch) + branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch) + branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch) + nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections)) + else: + r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) + + # mutate add node + aux_nodes, aux_connections = m_add_node(r1, nodes, connections) + nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes) + connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections) + + # mutate delete node + aux_nodes, aux_connections = m_delete_node(r2, nodes, connections) + nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes) + connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections) + + # mutate add connection + aux_nodes, aux_connections = m_add_connection(r3, nodes, connections) + nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes) + connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections) + + # mutate delete connection + aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections) + nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes) + connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections) + + nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength, + bias_mutate_rate, bias_replace_rate, response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate, + weight_mean, weight_std, weight_mutate_strength, + weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range, + agg_replace_rate, enabled_reverse_rate) + + return nodes, connections + + +@jit +def mutate_values(rand_key: Array, + nodes: Array, + connections: Array, + bias_mean: float = 0, + bias_std: float = 1, + bias_mutate_strength: float = 0.5, + bias_mutate_rate: float = 0.7, + bias_replace_rate: float = 0.1, + response_mean: float = 1., + response_std: float = 0., + response_mutate_strength: float = 0., + response_mutate_rate: float = 0., + response_replace_rate: float = 0., + weight_mean: float = 0., + weight_std: float = 1., + weight_mutate_strength: float = 0.5, + weight_mutate_rate: float = 0.7, + weight_replace_rate: float = 0.1, + act_range: int = 5, + act_replace_rate: float = 0.1, + agg_range: int = 5, + agg_replace_rate: float = 0.1, + enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]: + """ + Mutate values of nodes and connections. + + Args: + rand_key: A random key for generating random values. + nodes: A 2D array representing nodes. + connections: A 3D array representing connections. + bias_mean: Mean of the bias values. + bias_std: Standard deviation of the bias values. + bias_mutate_strength: Strength of the bias mutation. + bias_mutate_rate: Rate of the bias mutation. + bias_replace_rate: Rate of the bias replacement. + response_mean: Mean of the response values. + response_std: Standard deviation of the response values. + response_mutate_strength: Strength of the response mutation. + response_mutate_rate: Rate of the response mutation. + response_replace_rate: Rate of the response replacement. + weight_mean: Mean of the weight values. + weight_std: Standard deviation of the weight values. + weight_mutate_strength: Strength of the weight mutation. + weight_mutate_rate: Rate of the weight mutation. + weight_replace_rate: Rate of the weight replacement. + act_range: Range of the activation function values. + act_replace_rate: Rate of the activation function replacement. + agg_range: Range of the aggregation function values. + agg_replace_rate: Rate of the aggregation function replacement. + enabled_reverse_rate: Rate of reversing enabled state of connections. + + Returns: + A tuple containing mutated nodes and connections. + """ + + k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6) + bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std, + bias_mutate_strength, bias_mutate_rate, bias_replace_rate) + response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate) + weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std, + weight_mutate_strength, weight_mutate_rate, weight_replace_rate) + act_new = mutate_int_values(k4, nodes[:, 3], act_range, act_replace_rate) + agg_new = mutate_int_values(k5, nodes[:, 4], agg_range, agg_replace_rate) + + # refactor enabled + r = jax.random.uniform(rand_key, connections[1, :, :].shape) + enabled_new = connections[1, :, :] == 1 + enabled_new = jnp.where(r < enabled_reverse_rate, ~enabled_new, enabled_new) + enabled_new = jnp.where(~jnp.isnan(connections[0, :, :]), enabled_new, jnp.nan) + + nodes = nodes.at[:, 1].set(bias_new) + nodes = nodes.at[:, 2].set(response_new) + nodes = nodes.at[:, 3].set(act_new) + nodes = nodes.at[:, 4].set(agg_new) + connections = connections.at[0, :, :].set(weight_new) + connections = connections.at[1, :, :].set(enabled_new) + return nodes, connections + + +@jit +def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, + mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: + """ + Mutate float values of a given array. + + Args: + rand_key: A random key for generating random values. + old_vals: A 1D array of float values to be mutated. + mean: Mean of the values. + std: Standard deviation of the values. + mutate_strength: Strength of the mutation. + mutate_rate: Rate of the mutation. + replace_rate: Rate of the replacement. + + Returns: + A mutated 1D array of float values. + """ + k1, k2, k3, rand_key = jax.random.split(rand_key, num=4) + noise = jax.random.normal(k1, old_vals.shape) * mutate_strength + replace = jax.random.normal(k2, old_vals.shape) * std + mean + r = jax.random.uniform(k3, old_vals.shape) + new_vals = old_vals + new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals) + new_vals = jnp.where( + jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate), + replace, + new_vals + ) + new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) + return new_vals + + +@jit +def mutate_int_values(rand_key: Array, old_vals: Array, range: int, replace_rate: float) -> Array: + """ + Mutate integer values (act, agg) of a given array. + + Args: + rand_key: A random key for generating random values. + old_vals: A 1D array of integer values to be mutated. + range: Range of the integer values. + replace_rate: Rate of the replacement. + + Returns: + A mutated 1D array of integer values. + """ + k1, k2, rand_key = jax.random.split(rand_key, num=3) + replace_val = jax.random.randint(k1, old_vals.shape, 0, range) + r = jax.random.uniform(k2, old_vals.shape) + new_vals = old_vals + new_vals = jnp.where(r < replace_rate, replace_val, new_vals) + new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) + return new_vals + + +@jit +def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, + default_bias: float = 0, default_response: float = 1, + default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: + """ + Randomly add a new node from splitting a connection. + :param rand_key: + :param new_node_key: + :param nodes: + :param connections: + :param default_bias: + :param default_response: + :param default_act: + :param default_agg: + :return: + """ + # randomly choose a connection + from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) + + # disable the connection + connections = connections.at[1, from_idx, to_idx].set(False) + + # add a new node + nodes, connections = add_node(new_node_key, nodes, connections, + bias=default_bias, response=default_response, act=default_act, agg=default_agg) + new_idx = fetch_first(nodes[:, 0] == new_node_key) + + # add two new connections + weight = connections[0, from_idx, to_idx] + nodes, connections = add_connection_by_idx(from_idx, new_idx, nodes, connections, weight=0, enabled=True) + nodes, connections = add_connection_by_idx(new_idx, to_idx, nodes, connections, weight=weight, enabled=True) + + return nodes, connections + + +@jit +def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, + input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: + """ + Randomly delete a node. Input and output nodes are not allowed to be deleted. + :param rand_key: + :param nodes: + :param connections: + :param input_keys: + :param output_keys: + :return: + """ + # randomly choose a node + node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys, + allow_input_keys=False, allow_output_keys=False) + + # delete the node + aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections) + + # delete connections + aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan) + aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan) + + # check node_key valid + nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node + connections = jnp.where(jnp.isnan(node_key), connections, aux_connections) + + return nodes, connections + + +@jit +def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, + input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: + """ + Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, + cycles are not allowed. + :param rand_key: + :param nodes: + :param connections: + :param input_keys: + :param output_keys: + :return: + """ + # randomly choose two nodes + k1, k2 = jax.random.split(rand_key, num=2) + from_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys, + allow_input_keys=True, allow_output_keys=True) + to_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys, + allow_input_keys=False, allow_output_keys=True) + + def successful(): + new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections) + return new_nodes, new_connections + + def already_exist(): + new_connections = connections.at[1, from_idx, to_idx].set(True) + return nodes, new_connections + + def cycle(): + return nodes, connections + + is_already_exist = ~jnp.isnan(connections[0, from_idx, to_idx]) + is_cycle = check_cycles(nodes, connections, from_idx, to_idx) + + choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) + nodes, connections = jax.lax.switch(choice, [already_exist, cycle, successful]) + return nodes, connections + + +@jit +def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): + """ + Randomly delete a connection. + :param rand_key: + :param nodes: + :param connections: + :return: + """ + # randomly choose a connection + from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) + nodes, connections = delete_connection_by_idx(from_idx, to_idx, nodes, connections) + return nodes, connections + + +@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) +def choice_node_key(rand_key: Array, nodes: Array, + input_keys: Array, output_keys: Array, + allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: + """ + Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. + :param rand_key: + :param nodes: + :param input_keys: + :param output_keys: + :param allow_input_keys: + :param allow_output_keys: + :return: return its key and position(idx) + """ + + node_keys = nodes[:, 0] + mask = ~jnp.isnan(node_keys) + + if not allow_input_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) + + if not allow_output_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) + + idx = fetch_random(rand_key, mask) + key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) + return key, idx + + +@jit +def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: + """ + Randomly choose a connection key from the given connections. + :param rand_key: + :param nodes: + :param connection: + :return: from_key, to_key, from_idx, to_idx + """ + k1, k2 = jax.random.split(rand_key, num=2) + has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1) + from_idx = fetch_random(k1, has_connections_row) + col = connection[0, from_idx, :] + to_idx = fetch_random(k2, ~jnp.isnan(col)) + from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0] + return from_key, to_key, from_idx, to_idx + + +@jit +def rand(rand_key): + return jax.random.uniform(rand_key, ()) diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py new file mode 100644 index 0000000..57eef00 --- /dev/null +++ b/algorithms/neat/genome/utils.py @@ -0,0 +1,134 @@ +from functools import partial +from typing import Tuple + +import jax +from jax import numpy as jnp, Array +from jax import jit + +I_INT = jnp.iinfo(jnp.int32).max # infinite int + + +@jit +def flatten_connections(keys, connections): + """ + flatten the (2, N, N) connections to (N * N, 4) + :param keys: + :param connections: + :return: + the first two columns are the index of the node + the 3rd column is the weight, and the 4th column is the enabled status + """ + indices_x, indices_y = jnp.meshgrid(keys, keys, indexing='ij') + indices = jnp.stack((indices_x, indices_y), axis=-1).reshape(-1, 2) + + # make (2, N, N) to (N, N, 2) + con = jnp.transpose(connections, (1, 2, 0)) + # make (N, N, 2) to (N * N, 2) + con = jnp.reshape(con, (-1, 2)) + + con = jnp.concatenate((indices, con), axis=1) + return con + + +@partial(jit, static_argnames=['N']) +def unflatten_connections(N, cons): + """ + restore the (N * N, 4) connections to (2, N, N) + :param N: + :param cons: + :return: + """ + cons = cons[:, 2:] # remove the indices + unflatten_cons = jnp.moveaxis(cons.reshape(N, N, 2), -1, 0) + return unflatten_cons + + +@jit +def set_operation_analysis(ar1: Array, ar2: Array) -> Tuple[Array, Array, Array]: + """ + Analyze the intersection and union of two arrays by returning their sorted concatenation indices, + intersection mask, and union mask. + + :param ar1: JAX array of shape (N, M) + First input array. Should have the same shape as ar2. + :param ar2: JAX array of shape (N, M) + Second input array. Should have the same shape as ar1. + :return: tuple of 3 arrays + - sorted_indices: Indices that would sort the concatenation of ar1 and ar2. + - intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2 + in the sorted concatenation. + - union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2 + in the sorted concatenation. + + Examples: + a = jnp.array([[1, 2], [3, 4], [5, 6]]) + b = jnp.array([[1, 2], [7, 8], [9, 10]]) + + sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b) + + sorted_indices -> array([0, 1, 2, 3, 4, 5]) + intersect_mask -> array([True, False, False, False, False, False]) + union_mask -> array([False, True, True, True, True, True]) + """ + ar = jnp.concatenate((ar1, ar2), axis=0) + sorted_indices = jnp.lexsort(ar.T[::-1]) + aux = ar[sorted_indices] + aux = jnp.concatenate((aux, jnp.full((1, ar1.shape[1]), jnp.nan)), axis=0) + nan_mask = jnp.any(jnp.isnan(aux), axis=1) + + fr, sr = aux[:-1], aux[1:] # first row, second row + intersect_mask = jnp.all(fr == sr, axis=1) & ~nan_mask[:-1] + union_mask = jnp.any(fr != sr, axis=1) & ~nan_mask[:-1] + return sorted_indices, intersect_mask, union_mask + + +@jit +def fetch_first(mask, default=I_INT) -> Array: + """ + fetch the first True index + :param mask: array of bool + :param default: the default value if no element satisfying the condition + :return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT + example: + >>> a = jnp.array([1, 2, 3, 4, 5]) + >>> fetch_first(a > 3) + 3 + >>> fetch_first(a > 30) + I_INT + """ + idx = jnp.argmax(mask) + return jnp.where(mask[idx], idx, default) + + +@jit +def fetch_last(mask, default=I_INT) -> Array: + """ + similar to fetch_first, but fetch the last True index + """ + reversed_idx = fetch_first(mask[::-1], default) + return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1) + + +@jit +def fetch_random(rand_key, mask, default=I_INT) -> Array: + """ + similar to fetch_first, but fetch a random True index + """ + true_cnt = jnp.sum(mask) + cumsum = jnp.cumsum(mask) + target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) + return fetch_first(cumsum >= target, default) + + +if __name__ == '__main__': + a = jnp.array([1, 2, 3, 4, 5]) + print(fetch_first(a > 3)) + print(fetch_first(a > 30)) + + print(fetch_last(a > 3)) + print(fetch_last(a > 30)) + + rand_key = jax.random.PRNGKey(0) + for _ in range(100): + rand_key, _ = jax.random.split(rand_key) + print(fetch_random(rand_key, a > 0)) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py new file mode 100644 index 0000000..83ec44f --- /dev/null +++ b/algorithms/neat/pipeline.py @@ -0,0 +1,41 @@ +import jax + +from .species import SpeciesController +from .genome import create_initialize_function, create_mutate_function, create_forward_function + + +class Pipeline: + """ + Neat algorithm pipeline. + """ + + def __init__(self, config): + self.config = config + self.N = config.basic.init_maximum_nodes + + self.species_controller = SpeciesController(config) + self.initialize_func = create_initialize_function(config) + self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() + self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True) + + self.generation = 0 + + self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) + + def ask(self, batch: bool): + """ + Create a forward function for the population. + :param batch: + :return: + Algorithm gives the population a forward function, then environment gives back the fitnesses. + """ + func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx, + batch=batch) + return func + + def tell(self, fitnesses): + self.generation += 1 + print(type(fitnesses), fitnesses) + self.species_controller.update_species_fitnesses(fitnesses) + + diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py new file mode 100644 index 0000000..b73f169 --- /dev/null +++ b/algorithms/neat/species.py @@ -0,0 +1,190 @@ +from typing import List, Tuple, Dict +from itertools import count + +import jax +import numpy as np +from numpy.typing import NDArray +from .genome import distance + + +class Species(object): + + def __init__(self, key, generation): + self.key = key + self.created = generation + self.last_improved = generation + self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections) + self.members: List[int] = [] # idx in pop_nodes, pop_connections + self.fitness = None + self.member_fitnesses = None + self.adjusted_fitness = None + self.fitness_history: List[float] = [] + + def update(self, representative, members): + self.representative = representative + self.members = members + + def get_fitnesses(self, fitnesses): + return [fitnesses[m] for m in self.members] + + +class SpeciesController: + """ + A class to control the species + """ + + def __init__(self, config): + self.config = config + self.compatibility_threshold = self.config.neat.species.compatibility_threshold + self.species_elitism = self.config.neat.species.species_elitism + self.max_stagnation = self.config.neat.species.max_stagnation + + self.species_idxer = count(0) + self.species: Dict[int, Species] = {} # species_id -> species + self.genome_to_species: Dict[int, int] = {} + + self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many + # self.o2o_distance_func = np_distance # one to one + self.o2o_distance_func = distance + + def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None: + """ + :param pop_nodes: + :param pop_connections: + :param generation: use to flag the created time of new species + :return: + """ + unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool) + previous_species_list = list(self.species.keys()) + + # Find the best representatives for each existing species. + new_representatives = {} + new_members = {} + + for sid, species in self.species.items(): + # calculate the distance between the representative and the population + r_nodes, r_connections = species.representative + distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections) + distances = jax.device_get(distances) # fetch the data from gpu + min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance + + new_representatives[sid] = min_idx + new_members[sid] = [min_idx] + unspeciated[min_idx] = False + + # Partition population into species based on genetic similarity. + + # First, fast match the population to previous species + rid_list = [new_representatives[sid] for sid in previous_species_list] + res_pop_distance = [ + jax.device_get( + [ + self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) + for rid in rid_list + ] + ) + ] + pop_res_distance = np.stack(res_pop_distance, axis=0).T + for i in range(pop_res_distance.shape[0]): + if not unspeciated[i]: + continue + min_idx = np.argmin(pop_res_distance[i]) + min_val = pop_res_distance[i, min_idx] + if min_val <= self.compatibility_threshold: + species_id = previous_species_list[min_idx] + new_members[species_id].append(i) + unspeciated[i] = False + + # Second, slowly match the lonely population to new-created species. + # lonely genome is proved to be not compatible with any previous species, so they only need to be compared with + # the new representatives. + new_species_list = [] + for i in range(pop_nodes.shape[0]): + if not unspeciated[i]: + continue + unspeciated[i] = False + if len(new_representatives) != 0: + rid = [new_representatives[sid] for sid in new_representatives] # the representatives of new species + distances = [ + self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) + for r in rid + ] + distances = np.array(distances) + min_idx = np.argmin(distances) + min_val = distances[min_idx] + if min_val <= self.compatibility_threshold: + species_id = new_species_list[min_idx] + new_members[species_id].append(i) + continue + # create a new species + species_id = next(self.species_idxer) + new_species_list.append(species_id) + new_representatives[species_id] = i + new_members[species_id] = [i] + + assert np.all(~unspeciated) + # Update species collection based on new speciation. + self.genome_to_species = {} + for sid, rid in new_representatives.items(): + s = self.species.get(sid) + if s is None: + s = Species(sid, generation) + self.species[sid] = s + + members = new_members[sid] + for gid in members: + self.genome_to_species[gid] = sid + + s.update((pop_nodes[rid], pop_connections[rid]), members) + + def update_species_fitnesses(self, fitnesses): + """ + update the fitness of each species + :param fitnesses: + :return: + """ + for sid, s in self.species.items(): + # TODO: here use mean to measure the fitness of a species, but it may be other functions + s.member_fitnesses = s.get_fitnesses(fitnesses) + s.fitness = np.mean(s.member_fitnesses) + s.fitness_history.append(s.fitness) + s.adjusted_fitness = None + + def stagnation(self, generation): + """ + code modified from neat-python! + :param generation: + :return: whether the species is stagnated + """ + species_data = [] + for sid, s in self.species.items(): + if s.fitness_history: + prev_fitness = max(s.fitness_history) + else: + prev_fitness = float('-inf') + + if prev_fitness is None or s.fitness > prev_fitness: + s.last_improved = generation + + species_data.append((sid, s)) + + # Sort in descending fitness order. + species_data.sort(key=lambda x: x[1].fitness, reverse=True) + + result = [] + for idx, (sid, s) in enumerate(species_data): + + if idx < self.species_elitism: # elitism species never stagnate! + is_stagnant = False + else: + stagnant_time = generation - s.last_improved + is_stagnant = stagnant_time > self.max_stagnation + + result.append((sid, s, is_stagnant)) + return result + + +def find_min_with_mask(arr: NDArray, mask: NDArray) -> int: + masked_arr = np.where(mask, arr, np.inf) + min_idx = np.argmin(masked_arr) + return min_idx diff --git a/algorithms/neat/stagnation.py b/algorithms/neat/stagnation.py new file mode 100644 index 0000000..5b0ce3d --- /dev/null +++ b/algorithms/neat/stagnation.py @@ -0,0 +1,62 @@ +""" +Code modified from NEAT-Python library +Keeps track of whether species are making progress and helps remove those which are not. +""" + + +class Stagnation: + """Keeps track of whether species are making progress and helps remove ones that are not.""" + + def __init__(self, config): + self.config = config + + def update(self, species_set, generation): + """ + Required interface method. Updates species fitness history information, + checking for ones that have not improved in max_stagnation generations, + and - unless it would result in the number of species dropping below the configured + species_elitism parameter if they were removed, + in which case the highest-fitness species are spared - + returns a list with stagnant species marked for removal. + """ + species_data = [] + for sid, s in species_set.species.items(): + if s.fitness_history: + prev_fitness = max(s.fitness_history) + else: + prev_fitness = float('-inf') + + s.fitness = max(s.get_fitnesses()) + s.fitness_history.append(s.fitness) + s.adjusted_fitness = None + if prev_fitness is None or s.fitness > prev_fitness: + s.last_improved = generation + + species_data.append((sid, s)) + + # Sort in ascending fitness order. + species_data.sort(key=lambda x: x[1].fitness) + + result = [] + species_fitnesses = [] + num_non_stagnant = len(species_data) + for idx, (sid, s) in enumerate(species_data): + # Override stagnant state if marking this species as stagnant would + # result in the total number of species dropping below the limit. + # Because species are in ascending fitness order, less fit species + # will be marked as stagnant first. + stagnant_time = generation - s.last_improved + is_stagnant = False + if num_non_stagnant > self.config.stagnation.species_elitism: + is_stagnant = stagnant_time >= self.config.stagnation.max_stagnation + + if (len(species_data) - idx) <= self.config.stagnation.species_elitism: + is_stagnant = False + + if is_stagnant: + num_non_stagnant -= 1 + + result.append((sid, s, is_stagnant)) + species_fitnesses.append(s.fitness) + + return result diff --git a/algorithms/numpy/__init__.py b/algorithms/numpy/__init__.py new file mode 100644 index 0000000..0c1d0ac --- /dev/null +++ b/algorithms/numpy/__init__.py @@ -0,0 +1,5 @@ +""" +numpy version of functions in genome +""" +from .distance import distance +from .utils import * \ No newline at end of file diff --git a/algorithms/numpy/distance.py b/algorithms/numpy/distance.py new file mode 100644 index 0000000..e56f2ff --- /dev/null +++ b/algorithms/numpy/distance.py @@ -0,0 +1,58 @@ +import numpy as np + +from .utils import flatten_connections, set_operation_analysis + + +def distance(nodes1, connections1, nodes2, connections2): + node_distance = gene_distance(nodes1, nodes2, 'node') + + # refactor connections + keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + cons1 = flatten_connections(keys1, connections1) + cons2 = flatten_connections(keys2, connections2) + + connection_distance = gene_distance(cons1, cons2, 'connection') + return node_distance + connection_distance + + +def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): + if gene_type == 'node': + keys1, keys2 = ar1[:, :1], ar2[:, :1] + else: # connection + keys1, keys2 = ar1[:, :2], ar2[:, :2] + + n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) + nodes = np.concatenate((ar1, ar2), axis=0) + sorted_nodes = nodes[n_sorted_indices] + fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] + + non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask) + if gene_type == 'node': + node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) + else: # connection + node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes) + + node_distance = np.where(np.isnan(node_distance), 0, node_distance) + homologous_distance = np.sum(node_distance * n_intersect_mask[:-1]) + + gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1)) + gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1)) + + val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe + return val / np.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2) + + +def homologous_node_distance(n1, n2): + d = 0 + d += np.abs(n1[:, 1] - n2[:, 1]) # bias + d += np.abs(n1[:, 2] - n2[:, 2]) # response + d += n1[:, 3] != n2[:, 3] # activation + d += n1[:, 4] != n2[:, 4] + return d + + +def homologous_connection_distance(c1, c2): + d = 0 + d += np.abs(c1[:, 2] - c2[:, 2]) # weight + d += c1[:, 3] != c2[:, 3] # enable + return d diff --git a/algorithms/numpy/utils.py b/algorithms/numpy/utils.py new file mode 100644 index 0000000..57119a3 --- /dev/null +++ b/algorithms/numpy/utils.py @@ -0,0 +1,55 @@ +import numpy as np + +I_INT = np.iinfo(np.int32).max # infinite int + + +def flatten_connections(keys, connections): + indices_x, indices_y = np.meshgrid(keys, keys, indexing='ij') + indices = np.stack((indices_x, indices_y), axis=-1).reshape(-1, 2) + + # make (2, N, N) to (N, N, 2) + con = np.transpose(connections, (1, 2, 0)) + # make (N, N, 2) to (N * N, 2) + con = np.reshape(con, (-1, 2)) + + con = np.concatenate((indices, con), axis=1) + return con + + +def unflatten_connections(N, cons): + cons = cons[:, 2:] # remove the indices + unflatten_cons = np.moveaxis(cons.reshape(N, N, 2), -1, 0) + return unflatten_cons + + +def set_operation_analysis(ar1, ar2): + ar = np.concatenate((ar1, ar2), axis=0) + sorted_indices = np.lexsort(ar.T[::-1]) + aux = ar[sorted_indices] + aux = np.concatenate((aux, np.full((1, ar1.shape[1]), np.nan)), axis=0) + nan_mask = np.any(np.isnan(aux), axis=1) + + fr, sr = aux[:-1], aux[1:] # first row, second row + intersect_mask = np.all(fr == sr, axis=1) & ~nan_mask[:-1] + union_mask = np.any(fr != sr, axis=1) & ~nan_mask[:-1] + return sorted_indices, intersect_mask, union_mask + + +def fetch_first(mask, default=I_INT): + idx = np.argmax(mask) + return np.where(mask[idx], idx, default) + + +def fetch_last(mask, default=I_INT): + reversed_idx = fetch_first(mask[::-1], default) + return np.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1) + + +def fetch_random(rand_key, mask, default=I_INT): + """ + similar to fetch_first, but fetch a random True index + """ + true_cnt = np.sum(mask) + cumsum = np.cumsum(mask) + target = np.random.randint(rand_key, shape=(), minval=0, maxval=true_cnt + 1) + return fetch_first(cumsum >= target, default) diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/genome_test.py b/examples/genome_test.py new file mode 100644 index 0000000..6c00865 --- /dev/null +++ b/examples/genome_test.py @@ -0,0 +1,71 @@ +import time + +import jax.random + +from utils import Configer +from algorithms.neat.genome.genome import * + +from algorithms.neat.species import SpeciesController +from algorithms.neat.genome.forward import create_forward_function +from algorithms.neat.genome.mutate import create_mutate_function + +if __name__ == '__main__': + N = 10 + pop_nodes, pop_connections, input_idx, output_idx = initialize_genomes(10000, N, 2, 1, + default_act=9, default_agg=0) + inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + # forward = create_forward_function(pop_nodes, pop_connections, 5, input_idx, output_idx, batch=True) + nodes, connections = pop_nodes[0], pop_connections[0] + + forward = create_forward_function(pop_nodes, pop_connections, N, input_idx, output_idx, batch=True) + out = forward(inputs) + print(out.shape) + print(out) + + config = Configer.load_config() + s_c = SpeciesController(config.neat) + s_c.speciate(pop_nodes, pop_connections, 0) + s_c.speciate(pop_nodes, pop_connections, 0) + print(s_c.genome_to_species) + + start = time.time() + for i in range(100): + print(i) + s_c.speciate(pop_nodes, pop_connections, i) + print(time.time() - start) + + seed = jax.random.PRNGKey(42) + mutate_func = create_mutate_function(config, input_idx, output_idx, batch=False) + print(nodes, connections, sep='\n') + print(*mutate_func(seed, nodes, connections, 100), sep='\n') + + randseeds = jax.random.split(seed, 10000) + new_node_keys = jax.random.randint(randseeds[0], minval=0, maxval=10000, shape=(10000,)) + batch_mutate_func = create_mutate_function(config, input_idx, output_idx, batch=True) + pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys) + print(pop_nodes, pop_connections, sep='\n') + + start = time.time() + for i in range(100): + print(i) + pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys) + print(time.time() - start) + + print(nodes, connections, sep='\n') + nodes, connections = add_node(6, nodes, connections) + nodes, connections = add_node(7, nodes, connections) + print(nodes, connections, sep='\n') + + nodes, connections = add_connection(6, 7, nodes, connections) + nodes, connections = add_connection(0, 7, nodes, connections) + nodes, connections = add_connection(1, 7, nodes, connections) + print(nodes, connections, sep='\n') + + nodes, connections = delete_connection(6, 7, nodes, connections) + print(nodes, connections, sep='\n') + + nodes, connections = delete_node(6, nodes, connections) + print(nodes, connections, sep='\n') + + nodes, connections = delete_node(7, nodes, connections) + print(nodes, connections, sep='\n') diff --git a/examples/jax_playground.py b/examples/jax_playground.py new file mode 100644 index 0000000..98c96b9 --- /dev/null +++ b/examples/jax_playground.py @@ -0,0 +1,37 @@ +import jax +import jax.numpy as jnp +import numpy as np +from jax import random +from jax import vmap, jit + + +def plus1(x): + return x + 1 + + +def minus1(x): + return x - 1 + + +def func(rand_key, x): + r = jax.random.uniform(rand_key, shape=()) + return jax.lax.cond(r > 0.5, plus1, minus1, x) + + +def func2(rand_key): + r = jax.random.uniform(rand_key, ()) + if r < 0.3: + return 1 + elif r < 0.5: + return 2 + else: + return 3 + + + +key = random.PRNGKey(0) +print(func(key, 0)) + +batch_func = vmap(jit(func)) +keys = random.split(key, 100) +print(batch_func(keys, jnp.zeros(100))) \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py new file mode 100644 index 0000000..abd384b --- /dev/null +++ b/examples/xor.py @@ -0,0 +1,40 @@ +from typing import Callable, List + +import jax +import numpy as np + +from utils import Configer +from algorithms.neat import Pipeline + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) +xor_outputs = np.array([[0], [1], [1], [0]]) + + +def evaluate(forward_func: Callable) -> List[float]: + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses.tolist() # returns a list + + +def main(): + config = Configer.load_config() + pipeline = Pipeline(config) + forward_func = pipeline.ask(batch=True) + fitnesses = evaluate(forward_func) + pipeline.tell(fitnesses) + + + + # for i in range(100): + # forward_func = pipeline.ask(batch=True) + # fitnesses = evaluate(forward_func) + # pipeline.tell(fitnesses) + + +if __name__ == '__main__': + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..dfb91b6 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .config import Configer diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0fdb80a330e145f48094b35382468ae61aaaa4a GIT binary patch literal 177 zcmYe~<>g`kf|;_xDRMyiF^Gc(44TX@8G*u@ zjJG(P^YhX&(^HH5G?{L(C4)IdAiXOYidcXYnE0iypPN^rpQ@jinpmQnl~@61L<2do m`lThAImP<%@tJv*{4Jil! literal 0 HcmV?d00001 diff --git a/utils/__pycache__/config.cpython-39.pyc b/utils/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34f737e32251fbe8855e8c580f85fb919310be43 GIT binary patch literal 2914 zcmZuzOK%)S5boE$*N$V7C^2BI0U<3Uao~VZWKbXpuY)CmM3h0%Fqv+zXEQsq>F!PJ z)%t>SMd83LinPlue}W&<;>Ia|ffH3dYiHK+th%#DChT3Tunc>ezCNBK{S zkbkjLemHb)Kr^3#kVMjy1jRoMsPD6YNhW#92Ox3`km#yDDj1axjdGY>#0g8l=hxTKOjC5JQ+pmQm_UID!#t6r~yUXyjN z*FbN`rq}DBugJDsdP-q-1A5EyjDORVE%>gpu6{Gmc9NdbF@D?8gL|$(y9v!)24P8` zPMMe7iW!~J1L}K~36*rl4%mcDs}nZE-LmcR`Ky~!?L?#0hOuAZYTx8px8_)GoES!S z*VSaAvuL1PEqY-5NIeX9l2nDE6a75N+=cMyuBGfMLi4`xjbnuqufC=+&z-2d)?PxypBrurjGrU3bbKPQq+Sf4#7FCc~~ zc_b`YO!y^Y#UaxIVi@8X`xAz8#?VF*fRS&HSMMk}%4C$;IfmCeW6?>nPPgjnPa>Tq zSr3py%?Tgb&BP>`u~8PQ04A=8~61kYdeXfy;GJktBz0@ds9UiW(?yZLB!?q?VL@MAyK zk{__}e*SnpuJ^V@4L*$zy1JD;FKemD$aFu92p;jm01Bn*a}?{tMmf?(^WbyANz?(+Y9H8 z_BIfrSH2sh-WcqEe z(Q7CcdC0X3qvGU#0E78O0qfWJ{PDKSR$ZXJmV?_4774iRG{bNJIpuFY?J$IVOXqJ| zkeFHj(FmA=h?CDv1o0ge|C^A0LNip>hN2DJLTEi>9a!>G6H oRy{ODAYtqTHJMwi??r&2x6_t!^{rwLU*bK|;DT6k>a5NG2fE(Rd;kCd literal 0 HcmV?d00001 diff --git a/utils/__pycache__/dotdict.cpython-39.pyc b/utils/__pycache__/dotdict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae439dc23013598a1bc02a1771669d8b4e1ecebe GIT binary patch literal 1754 zcmahJO>Z1E)b@C0XD1(ov_Oj>thhvr)I{POs+NWz^$-cAr-90vjGcBivoopfZRu)q zVI%d%i5sny2riNMMZR*%U&x{Go_9B!&{TNzo}ZuJJiqU3VId)~{`&R@`A3V8zwvN+ zaX8!qFb@HJPb87_2tbihkW$4Y+h=-929nEAMg^6v*G$H*SsEyM%vB)UCoBzBq{3dS z7x&sDN{LK%NOvv=eOFd6IhlJvD3Vf1Qud1MQAs5`iqimJIU=cl3HDI7WDI+xTCy#Z zBQm0C{EBdrwo$)pZPs?P$ZgIpJ%jEbd;nlp09i63J=(B>G{K+rS1_-k4LG;~u7Vo2 z?Rb^-RhPNY3=Wibu@FU7SRuOHaihvTjrtwW^b_Nv{TH~OQr8wDFEe9=I3?$6>;1~E ztMy7{_D(PR`GWWn4v*IdwkXZItgQrd*A8Clsg(u^LYL`qY2wV<^z|!Vp#?F(0|0Xu zP(zQ&2|1+b%X6Y{HNd$I{#XcRLwc+i{7xhTio+f>6fk|$8qr@M=sV<;Z=ciEb9!fZ zbG7p{Go7mLWY+3pcVJa#^)z&0S-*f-EZ(1Tztg(bx(!cbfG}c0_`sYp9U_;b0*$pc zMsvuV6XJ=ocq0VbZx9h0x{`?qe!(>R0{jgK z3l3R2VrH?SI|O3PM)VL0?U7YqX~%+{aOx6Xe&1_#Ir$GPMQoy zgvONVPY~Tg={Y}o`pQ17r8mGWOvX{J$`U$rmVGQ+jm@YjAnYiEZ2(5`xD`ini&1}r zZQaj`N{BNQJo7`3sx$lph9k|>=>CIErS|ZUIlAk(u4}v9P4wp=;&pge(z!|*tE3Ox z*f=Jo3lDTrS-l3bX_Vy>H0iJK$C;^~@^d5J7zg97w-*Dnv=%d5{2%s~zm1Z4`Y%&h Bi8BBI literal 0 HcmV?d00001 diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..2d1dbda --- /dev/null +++ b/utils/config.py @@ -0,0 +1,78 @@ +import json +import os +import warnings + +from .dotdict import DotDict + + +class Configer: + @classmethod + def __load_default_config(cls): + par_dir = os.path.dirname(os.path.abspath(__file__)) + default_config_path = os.path.join(par_dir, "./default_config.json") + return cls.__load_config(default_config_path) + + @classmethod + def __load_config(cls, config_path): + with open(config_path, "r") as f: + text = "".join(f.readlines()) + try: + j = json.loads(text) + except ValueError: + raise Exception("Invalid config") + return DotDict.from_dict(j, "root") + + @classmethod + def __check_redundant_config(cls, default_config, config): + for key in config: + if key not in default_config: + warnings.warn(f"Redundant config: {key} in {config.name}") + continue + if isinstance(default_config[key], DotDict): + cls.__check_redundant_config(default_config[key], config[key]) + + @classmethod + def __complete_config(cls, default_config, config): + for key in default_config: + if key not in config: + config[key] = default_config[key] + continue + if isinstance(default_config[key], DotDict): + cls.__complete_config(default_config[key], config[key]) + + @classmethod + def __decorate_config(cls, config): + if config.neat.gene.activation.options == 'all': + config.neat.gene.activation.options = [ + "sigmoid", "tanh", "sin", "gauss", "relu", "elu", "lelu", "selu", "softplus", "identity", "clamped", + "inv", "log", "exp", "abs", "hat", "square", "cube" + ] + if isinstance(config.neat.gene.activation.options, str): + config.neat.gene.activation.options = [config.neat.gene.activation.options] + + if config.neat.gene.aggregation.options == 'all': + config.neat.gene.aggregation.options = ["product", "sum", "max", "min", "median", "mean"] + if isinstance(config.neat.gene.aggregation.options, str): + config.neat.gene.aggregation.options = [config.neat.gene.aggregation.options] + + @classmethod + def load_config(cls, config_path=None): + default_config = cls.__load_default_config() + if config_path is None: + config = DotDict("root") + elif not os.path.exists(config_path): + warnings.warn(f"config file {config_path} not exist!") + config = DotDict("root") + else: + config = cls.__load_config(config_path) + + cls.__check_redundant_config(default_config, config) + cls.__complete_config(default_config, config) + cls.__decorate_config(config) + return config + + @classmethod + def write_config(cls, config, write_path): + text = json.dumps(config, indent=2) + with open(write_path, "w") as f: + f.write(text) diff --git a/utils/default_config.json b/utils/default_config.json new file mode 100644 index 0000000..e4db63b --- /dev/null +++ b/utils/default_config.json @@ -0,0 +1,108 @@ +{ + "basic": { + "num_inputs": 2, + "num_outputs": 1, + "init_maximum_nodes": 20, + "expands_coe": 1.5 + }, + "neat": { + "population": { + "fitness_criterion": "max", + "fitness_threshold": 43.9999, + "generation_limit": 100, + "pop_size": 1000, + "reset_on_extinction": "False" + }, + "gene": { + "bias": { + "init_mean": 0.0, + "init_stdev": 1.0, + "max_value": 30.0, + "min_value": -30.0, + "mutate_power": 0.5, + "mutate_rate": 0.7, + "replace_rate": 0.1 + }, + "response": { + "init_mean": 1.0, + "init_stdev": 0.0, + "max_value": 30.0, + "min_value": -30.0, + "mutate_power": 0.0, + "mutate_rate": 0.0, + "replace_rate": 0.0 + }, + "activation": { + "default": "sigmoid", + "options": "sigmoid", + "mutate_rate": 0.01 + }, + "aggregation": { + "default": "sum", + "options": [ + "product", + "sum", + "max", + "min", + "median", + "mean" + ], + "mutate_rate": 0.01 + }, + "weight": { + "init_mean": 0.0, + "init_stdev": 1.0, + "max_value": 30.0, + "min_value": -30.0, + "mutate_power": 0.5, + "mutate_rate": 0.8, + "replace_rate": 0.1 + }, + "enabled": { + "mutate_rate": 0.01 + } + }, + "genome": { + "compatibility_disjoint_coefficient": 1.0, + "compatibility_weight_coefficient": 0.5, + "feedforward": "True", + "single_structural_mutation": "False", + "conn_add_prob": 0.5, + "conn_delete_prob": 0.5, + "node_add_prob": 0.2, + "node_delete_prob": 0.2 + }, + "species": { + "compatibility_threshold": 3.5, + "species_fitness_func": "max", + "max_stagnation": 20, + "species_elitism": 2, + "genome_elitism": 2, + "survival_threshold": 0.2, + "min_species_size": 1 + } + }, + "hyperneat": { + "substrate": { + "type": "feedforward", + "layers": [ + 3, + 10, + 10, + 1 + ], + "x_lim": [ + -5, + 5 + ], + "y_lim": [ + -5, + 5 + ], + "threshold": 0.2, + "max_weight": 5.0 + } + }, + "es-hyperneat": { + } +} \ No newline at end of file diff --git a/utils/dotdict.py b/utils/dotdict.py new file mode 100644 index 0000000..713c2b1 --- /dev/null +++ b/utils/dotdict.py @@ -0,0 +1,61 @@ +# DotDict For Config. Case Insensitive. + +class DotDict(dict): + def __init__(self, name, *args, **kwargs): + super().__init__(*args, **kwargs) + self["name"] = name + + def __getattr__(self, attr): + attr = attr.lower() # case insensitive + if attr in self: + return self[attr] + else: + raise AttributeError(f"'{self.__class__.__name__}-{self.name}' has no attribute '{attr}'") + + def __setattr__(self, attr, value): + attr = attr.lower() # case insensitive + if attr not in self: + raise AttributeError(f"'{self.__class__.__name__}-{self.name}' has no attribute '{attr}'") + self[attr] = value + + def __delattr__(self, attr): + attr = attr.lower() # case insensitive + if attr in self: + del self[attr] + else: + raise AttributeError(f"{self.__class__.__name__}-{self.name} object has no attribute '{attr}'") + + @classmethod + def from_dict(cls, d, name): + if not isinstance(d, dict): + return d + + dot_dict = cls(name) + for key, value in d.items(): + key = key.lower() # case insensitive + if isinstance(value, dict): + dot_dict[key] = cls.from_dict(value, key) + else: + dot_dict[key] = value + if dot_dict[key] == "True": # Fuck! Json has no bool type! + dot_dict[key] = True + if dot_dict[key] == "False": + dot_dict[key] = False + if dot_dict[key] == "None": + dot_dict[key] = None + return dot_dict + + +if __name__ == '__main__': + nested_dict = { + "a": 1, + "b": { + "c": 2, + "ACDeef": { + "e": 3 + } + } + } + + dd = DotDict.from_dict(nested_dict, "root") + print(dd.b.acdeef.e) # 输出:3