modifying
This commit is contained in:
@@ -2,11 +2,15 @@ import os
|
||||
import warnings
|
||||
import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .activations import refactor_act
|
||||
from .aggregations import refactor_agg
|
||||
|
||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||
jit_config_keys = [
|
||||
"input_idx",
|
||||
"output_idx",
|
||||
"compatibility_disjoint",
|
||||
"compatibility_weight",
|
||||
"conn_add_prob",
|
||||
@@ -88,10 +92,14 @@ class Configer:
|
||||
|
||||
refactor_act(config)
|
||||
refactor_agg(config)
|
||||
|
||||
input_idx = np.arange(config['num_inputs'])
|
||||
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
config['input_idx'] = input_idx
|
||||
config['output_idx'] = output_idx
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def create_jit_config(cls, config):
|
||||
jit_config = {k: config[k] for k in jit_config_keys}
|
||||
|
||||
return jit_config
|
||||
|
||||
Reference in New Issue
Block a user