modifying

This commit is contained in:
wls2002
2023-06-19 17:32:34 +08:00
parent 5cbe3c14bb
commit 35b095ba74
6 changed files with 428 additions and 42 deletions

View File

@@ -2,11 +2,15 @@ import os
import warnings
import configparser
import numpy as np
from .activations import refactor_act
from .aggregations import refactor_agg
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [
"input_idx",
"output_idx",
"compatibility_disjoint",
"compatibility_weight",
"conn_add_prob",
@@ -88,10 +92,14 @@ class Configer:
refactor_act(config)
refactor_agg(config)
input_idx = np.arange(config['num_inputs'])
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
config['input_idx'] = input_idx
config['output_idx'] = output_idx
return config
@classmethod
def create_jit_config(cls, config):
jit_config = {k: config[k] for k in jit_config_keys}
return jit_config