asreviewcontrib.simulation.api.unwrapping
Offers functionality to unwrap the abstraction layer provided by
asreviewcontrib.simulation, in order to get objects from the
asreview library directly.
Example usage:
from asreviewcontrib.simulation.api.unwrapping import get_review_simulate_kwargs # (do something interesting with get_review_simulate_kwargs)
1""" 2Offers functionality to unwrap the abstraction layer provided by 3`asreviewcontrib.simulation`, in order to get objects from the 4`asreview` library directly. 5 6Example usage: 7 8 ```python 9 from asreviewcontrib.simulation.api.unwrapping import get_review_simulate_kwargs 10 11 12 # (do something interesting with get_review_simulate_kwargs) 13 ``` 14""" 15from asreviewcontrib.simulation._private.lib.unwrapping.get_review_simulate_kwargs import get_review_simulate_kwargs 16from asreviewcontrib.simulation._private.lib.unwrapping.instantiate_unwrapped_model import instantiate_unwrapped_model 17 18 19__all__ = [ 20 "get_review_simulate_kwargs", 21 "instantiate_unwrapped_model", 22]
def
get_review_simulate_kwargs( config: asreviewcontrib.simulation._private.lib.config.Config, as_data: asreview.data.base.ASReviewData, seed: Optional[int] = None) -> dict:
def get_review_simulate_kwargs(config: Config, as_data: ASReviewData, seed: Optional[int] = None) -> dict:
    """Assemble a dict of keyword arguments from a simulation config.

    Args:
        config: Configuration holding the classifier (``clr``), querier
            (``qry``), balancer (``bal``) and feature extractor (``fex``)
            model settings.
        as_data: The dataset; its ``texts`` are used to build the embedding
            matrix for the LSTM classifiers, and it is passed on to the
            stopping / prior-sampling helpers.
        seed: Seed for the ``numpy.random.RandomState`` that is shared by
            all four instantiated models. ``None`` leaves it unseeded.

    Returns:
        A dict with the keys ``model``, ``query_model``, ``balance_model``,
        ``feature_model``, ``n_papers``, ``n_instances``, ``stop_if``,
        ``prior_indices``, ``n_prior_included``, ``n_prior_excluded`` and
        ``init_seed``.
    """
    rng = numpy.random.RandomState(seed)

    # All four models receive the same random state object; they are
    # instantiated in a fixed order (classifier, querier, balancer, extractor).
    kwargs = {
        "model": instantiate_unwrapped_model(config.clr, random_state=rng),
        "query_model": instantiate_unwrapped_model(config.qry, random_state=rng),
        "balance_model": instantiate_unwrapped_model(config.bal, random_state=rng),
        "feature_model": instantiate_unwrapped_model(config.fex, random_state=rng),
    }

    # The LSTM classifiers need an embedding matrix, which is produced by the
    # feature extractor from the dataset texts.
    if config.clr.abbr in ("clr-lstm-base", "clr-lstm-pool"):
        embedding_path = config.fex.params.get("embedding")
        kwargs["model"].embedding_matrix = kwargs["feature_model"].get_embedding_matrix(
            as_data.texts, embedding_path
        )

    stopping_condition = unwrap_stopping_vars(config, as_data)
    indices, n_included, n_excluded, prior_seed = unwrap_prior_sampling_vars(config, as_data)

    kwargs.update(
        n_papers=None,
        n_instances=config.qry.params.get("n_instances"),
        stop_if=stopping_condition,
        prior_indices=indices,
        n_prior_included=n_included,
        n_prior_excluded=n_excluded,
        init_seed=prior_seed,
    )
    return kwargs
def
instantiate_unwrapped_model( model: asreviewcontrib.simulation._private.lib.config.OneModelConfig, random_state):
def instantiate_unwrapped_model(model: OneModelConfig, random_state):
    """Instantiate the model object that corresponds to a model configuration.

    The instantiator is selected by ``model.abbr`` from a table of built-in
    instantiators, extended with the ``impl`` of every entry yielded by
    ``get_quads()`` whose abbreviation starts with a recognized flavor prefix
    (``bal``, ``clr``, ``fex`` or ``qry``). Entries from ``get_quads()``
    replace built-in entries that share the same abbreviation.

    Args:
        model: The configuration of a single model; its ``abbr`` selects the
            instantiator and its ``params`` are passed to it.
        random_state: Passed through unchanged as the second argument of the
            selected instantiator.

    Returns:
        Whatever the selected instantiator returns for
        ``(model.params, random_state)``.

    Raises:
        AssertionError: If ``model`` is not a ``OneModelConfig`` instance.
        KeyError: If ``model.abbr`` is not a known model name. A ``KeyError``
            raised inside the instantiator itself propagates unchanged.
    """
    assert isinstance(model, OneModelConfig), "Input argument 'model' needs to be an instance of OneModelConfig"
    instantiators = {
        "bal-double": instantiate_unwrapped_bal_double,
        "bal-simple": instantiate_unwrapped_bal_simple,
        "bal-undersample": instantiate_unwrapped_bal_undersample,
        "clr-logistic": instantiate_unwrapped_clr_logistic,
        "clr-lstm-base": instantiate_unwrapped_clr_lstm_base,
        "clr-lstm-pool": instantiate_unwrapped_clr_lstm_pool,
        "clr-nb": instantiate_unwrapped_clr_nb,
        "clr-nn-2-layer": instantiate_unwrapped_clr_nn_2_layer,
        "clr-rf": instantiate_unwrapped_clr_rf,
        "clr-svm": instantiate_unwrapped_clr_svm,
        "fex-doc2vec": instantiate_unwrapped_fex_doc2vec,
        "fex-embedding-idf": instantiate_unwrapped_fex_embedding_idf,
        "fex-embedding-lstm": instantiate_unwrapped_fex_embedding_lstm,
        "fex-sbert": instantiate_unwrapped_fex_sbert,
        "fex-tfidf": instantiate_unwrapped_fex_tfidf,
        "qry-cluster": instantiate_unwrapped_qry_cluster,
        "qry-max": instantiate_unwrapped_qry_max,
        "qry-max-random": instantiate_unwrapped_qry_max_random,
        "qry-max-uncertainty": instantiate_unwrapped_qry_max_uncertainty,
        "qry-random": instantiate_unwrapped_qry_random,
        "qry-uncertainty": instantiate_unwrapped_qry_uncertainty,
    }

    # Merge in the instantiators obtained from get_quads(); on a name clash
    # these take precedence over the built-in ones above.
    recognized_model_flavors = {"bal", "clr", "fex", "qry"}
    instantiators.update(
        {abbr: q.impl for abbr, q in get_quads() if abbr[:3] in recognized_model_flavors}
    )

    # Only the lookup is guarded, so a KeyError raised inside an instantiator
    # is not misreported as an unknown model name.
    try:
        instantiate = instantiators[model.abbr]
    except KeyError:
        abbrs = "\n".join(instantiators)
        msg = f"Undefined behavior for model name {model.abbr}. Valid model names are: {abbrs}"
        raise KeyError(msg) from None

    return instantiate(model.params, random_state)