asreviewcontrib.simulation.api.unwrapping

Offers functionality to unwrap the abstraction layer provided by asreviewcontrib.simulation, in order to get objects from the asreview library directly.

Example usage:
from asreviewcontrib.simulation.api.unwrapping import get_review_simulate_kwargs


# (do something interesting with get_review_simulate_kwargs)
 1"""
 2Offers functionality to unwrap the abstraction layer provided by
 3`asreviewcontrib.simulation`, in order to get objects from the
 4`asreview` library directly.
 5
 6Example usage:
 7
 8    ```python
 9    from asreviewcontrib.simulation.api.unwrapping import get_review_simulate_kwargs
10
11
12    # (do something interesting with get_review_simulate_kwargs)
13    ```
14"""
15from asreviewcontrib.simulation._private.lib.unwrapping.get_review_simulate_kwargs import get_review_simulate_kwargs
16from asreviewcontrib.simulation._private.lib.unwrapping.instantiate_unwrapped_model import instantiate_unwrapped_model
17
18
19__all__ = [
20    "get_review_simulate_kwargs",
21    "instantiate_unwrapped_model",
22]
def get_review_simulate_kwargs( config: asreviewcontrib.simulation._private.lib.config.Config, as_data: asreview.data.base.ASReviewData, seed: Optional[int] = None) -> dict:
def get_review_simulate_kwargs(config: Config, as_data: ASReviewData, seed: Optional[int] = None) -> dict:
    """Build a dictionary of keyword arguments from a configuration and a dataset.

    The classifier, query, balance and feature-extraction models described by
    ``config.clr``, ``config.qry``, ``config.bal`` and ``config.fex`` are
    instantiated through ``instantiate_unwrapped_model``; all four share one
    ``numpy.random.RandomState`` seeded with ``seed``.

    Args:
        config: Configuration holding the model sections (``clr``, ``qry``,
            ``bal``, ``fex``); it is also passed on to the stopping and prior
            sampling helpers.
        as_data: Dataset whose ``texts`` are used to build the embedding matrix
            for the LSTM classifiers; it is also passed on to the stopping and
            prior sampling helpers.
        seed: Seed for the shared random state; ``None`` lets numpy choose.

    Returns:
        A dict with the keys ``model``, ``query_model``, ``balance_model``,
        ``feature_model``, ``n_papers``, ``n_instances``, ``stop_if``,
        ``prior_indices``, ``n_prior_included``, ``n_prior_excluded`` and
        ``init_seed``.
        NOTE(review): presumably meant to be splatted into asreview's
        simulation reviewer -- confirm against the caller.
    """
    rng = numpy.random.RandomState(seed)

    # Instantiate in the fixed order classifier, querier, balancer, extractor:
    # the models share one random state object, so the order is kept stable.
    models = {
        "model": instantiate_unwrapped_model(config.clr, random_state=rng),
        "query_model": instantiate_unwrapped_model(config.qry, random_state=rng),
        "balance_model": instantiate_unwrapped_model(config.bal, random_state=rng),
        "feature_model": instantiate_unwrapped_model(config.fex, random_state=rng),
    }

    # The LSTM classifiers get an embedding matrix produced by the feature extractor.
    lstm_classifiers = ("clr-lstm-base", "clr-lstm-pool")
    if config.clr.abbr in lstm_classifiers:
        embedding_path = config.fex.params.get("embedding")
        models["model"].embedding_matrix = models["feature_model"].get_embedding_matrix(
            as_data.texts, embedding_path
        )

    stopping_rule = unwrap_stopping_vars(config, as_data)
    prior_indices, n_included, n_excluded, init_seed = unwrap_prior_sampling_vars(config, as_data)

    return {
        **models,
        "n_papers": None,
        "n_instances": config.qry.params.get("n_instances"),
        "stop_if": stopping_rule,
        "prior_indices": prior_indices,
        "n_prior_included": n_included,
        "n_prior_excluded": n_excluded,
        "init_seed": init_seed,
    }
def instantiate_unwrapped_model( model: asreviewcontrib.simulation._private.lib.config.OneModelConfig, random_state):
def instantiate_unwrapped_model(model: OneModelConfig, random_state):
    """Instantiate the model that ``model`` describes.

    The instantiator is selected by ``model.abbr`` from a table of built-in
    instantiators, extended with the implementations returned by
    ``get_quads()`` whose abbreviation starts with a recognized flavor prefix
    (``bal``, ``clr``, ``fex`` or ``qry``). The selected instantiator is called
    with ``model.params`` and ``random_state``.

    Args:
        model: Configuration of a single model; its ``abbr`` selects the
            instantiator and its ``params`` are passed on to it.
        random_state: Random state object handed to the instantiator unchanged.

    Returns:
        Whatever the selected instantiator returns.

    Raises:
        AssertionError: If ``model`` is not a ``OneModelConfig`` instance.
        KeyError: If ``model.abbr`` is not a known model name. A ``KeyError``
            raised inside the instantiator itself propagates unchanged.
    """
    assert isinstance(model, OneModelConfig), "Input argument 'model' needs to be an instance of OneModelConfig"
    instantiators = {
        "bal-double": instantiate_unwrapped_bal_double,
        "bal-simple": instantiate_unwrapped_bal_simple,
        "bal-undersample": instantiate_unwrapped_bal_undersample,
        "clr-logistic": instantiate_unwrapped_clr_logistic,
        "clr-lstm-base": instantiate_unwrapped_clr_lstm_base,
        "clr-lstm-pool": instantiate_unwrapped_clr_lstm_pool,
        "clr-nb": instantiate_unwrapped_clr_nb,
        "clr-nn-2-layer": instantiate_unwrapped_clr_nn_2_layer,
        "clr-rf": instantiate_unwrapped_clr_rf,
        "clr-svm": instantiate_unwrapped_clr_svm,
        "fex-doc2vec": instantiate_unwrapped_fex_doc2vec,
        "fex-embedding-idf": instantiate_unwrapped_fex_embedding_idf,
        "fex-embedding-lstm": instantiate_unwrapped_fex_embedding_lstm,
        "fex-sbert": instantiate_unwrapped_fex_sbert,
        "fex-tfidf": instantiate_unwrapped_fex_tfidf,
        "qry-cluster": instantiate_unwrapped_qry_cluster,
        "qry-max": instantiate_unwrapped_qry_max,
        "qry-max-random": instantiate_unwrapped_qry_max_random,
        "qry-max-uncertainty": instantiate_unwrapped_qry_max_uncertainty,
        "qry-random": instantiate_unwrapped_qry_random,
        "qry-uncertainty": instantiate_unwrapped_qry_uncertainty,
    }

    # Entries from get_quads() are applied after the built-ins, so on a name
    # clash they replace the built-in instantiator.
    recognized_model_flavors = {"bal", "clr", "fex", "qry"}
    instantiators.update({abbr: q.impl for abbr, q in get_quads() if abbr[:3] in recognized_model_flavors})

    # Guard only the lookup: a KeyError raised inside an instantiator (e.g. a
    # missing parameter) must not be misreported as an unknown model name.
    try:
        instantiator = instantiators[model.abbr]
    except KeyError:
        abbrs = "\n".join(instantiators)
        msg = f"Undefined behavior for model name {model.abbr}. Valid model names are: {abbrs}"
        raise KeyError(msg) from None

    return instantiator(model.params, random_state)