Best Python code snippet using slash
test_problems.py
Source: test_problems.py
1#2# This file is part of the chi repository3# (https://github.com/DavAug/chi/) which is released under the4# BSD 3-clause license. See accompanying LICENSE.md for copyright notice and5# full license details.6#7import copy8import unittest9import numpy as np10import pandas as pd11import pints12import chi13from chi.library import ModelLibrary14class TestProblemModellingControllerPDProblem(unittest.TestCase):15    """16    Tests the chi.ProblemModellingController class on a PD modelling17    problem.18    """19    @classmethod20    def setUpClass(cls):21        # Create test dataset22        ids_v = [0, 0, 0, 1, 1, 1, 2, 2]23        times_v = [0, 1, 2, 2, np.nan, 4, 1, 3]24        volumes = [np.nan, 0.3, 0.2, 0.5, 0.1, 0.2, 0.234, np.nan]25        ids_c = [0, 0, 1, 1]26        times_c = [0, 1, 2, np.nan]27        cytokines = [3.4, 0.3, 0.5, np.nan]28        ids_d = [0, 1, 1, 1, 2, 2]29        times_d = [0, np.nan, 4, 1, 3, 3]30        dose = [3.4, np.nan, 0.5, 0.5, np.nan, np.nan]31        duration = [0.01, np.nan, 0.31, np.nan, 0.5, np.nan]32        ids_cov = [0, 1, 2]33        times_cov = [np.nan, 1, np.nan]34        age = [10, 14, 12]35        cls.data = pd.DataFrame({36            'ID': ids_v + ids_c + ids_d + ids_cov,37            'Time': times_v + times_c + times_d + times_cov,38            'Observable':39                ['Tumour volume'] * 8 + ['IL 6'] * 4 + [np.nan] * 6 +40                ['Age'] * 3,41            'Value': volumes + cytokines + [np.nan] * 6 + age,42            'Dose': [np.nan] * 12 + dose + [np.nan] * 3,43            'Duration': [np.nan] * 12 + duration + [np.nan] * 3})44        # Test case I: create PD modelling problem45        lib = ModelLibrary()46        path = lib.tumour_growth_inhibition_model_koch()47        cls.pd_model = chi.PharmacodynamicModel(path)48        cls.error_model = chi.ConstantAndMultiplicativeGaussianErrorModel()49        cls.pd_problem = chi.ProblemModellingController(50            cls.pd_model, cls.error_model)51      
  # Test case II: create PKPD modelling problem52        lib = ModelLibrary()53        path = lib.erlotinib_tumour_growth_inhibition_model()54        cls.pkpd_model = chi.PharmacokineticModel(path)55        cls.pkpd_model.set_outputs([56            'central.drug_concentration',57            'myokit.tumour_volume'])58        cls.error_models = [59            chi.ConstantAndMultiplicativeGaussianErrorModel(),60            chi.ConstantAndMultiplicativeGaussianErrorModel()]61        cls.pkpd_problem = chi.ProblemModellingController(62            cls.pkpd_model, cls.error_models,63            outputs=[64                'central.drug_concentration',65                'myokit.tumour_volume'])66    def test_bad_input(self):67        # Mechanistic model has wrong type68        mechanistic_model = 'wrong type'69        with self.assertRaisesRegex(TypeError, 'The mechanistic model'):70            chi.ProblemModellingController(71                mechanistic_model, self.error_model)72        # Error model has wrong type73        error_model = 'wrong type'74        with self.assertRaisesRegex(TypeError, 'Error models have to be'):75            chi.ProblemModellingController(76                self.pd_model, error_model)77        error_models = ['wrong', 'type']78        with self.assertRaisesRegex(TypeError, 'Error models have to be'):79            chi.ProblemModellingController(80                self.pd_model, error_models)81        # Wrong number of error models82        error_model = chi.ConstantAndMultiplicativeGaussianErrorModel()83        with self.assertRaisesRegex(ValueError, 'Wrong number of error'):84            chi.ProblemModellingController(85                self.pkpd_model, error_model)86        error_models = [87            chi.ConstantAndMultiplicativeGaussianErrorModel(),88            chi.ConstantAndMultiplicativeGaussianErrorModel()]89        with self.assertRaisesRegex(ValueError, 'Wrong number of error'):90            chi.ProblemModellingController(91            
    self.pd_model, error_models)92    def test_fix_parameters(self):93        # Test case I: PD model94        # Fix model parameters95        name_value_dict = dict({96            'myokit.drug_concentration': 0,97            'Sigma base': 1})98        self.pd_problem.fix_parameters(name_value_dict)99        self.assertEqual(self.pd_problem.get_n_parameters(), 5)100        param_names = self.pd_problem.get_parameter_names()101        self.assertEqual(len(param_names), 5)102        self.assertEqual(param_names[0], 'myokit.tumour_volume')103        self.assertEqual(param_names[1], 'myokit.kappa')104        self.assertEqual(param_names[2], 'myokit.lambda_0')105        self.assertEqual(param_names[3], 'myokit.lambda_1')106        self.assertEqual(param_names[4], 'Sigma rel.')107        # Free and fix a parameter108        name_value_dict = dict({109            'myokit.lambda_1': 2,110            'Sigma base': None})111        self.pd_problem.fix_parameters(name_value_dict)112        self.assertEqual(self.pd_problem.get_n_parameters(), 5)113        param_names = self.pd_problem.get_parameter_names()114        self.assertEqual(len(param_names), 5)115        self.assertEqual(param_names[0], 'myokit.tumour_volume')116        self.assertEqual(param_names[1], 'myokit.kappa')117        self.assertEqual(param_names[2], 'myokit.lambda_0')118        self.assertEqual(param_names[3], 'Sigma base')119        self.assertEqual(param_names[4], 'Sigma rel.')120        # Free all parameters again121        name_value_dict = dict({122            'myokit.lambda_1': None,123            'myokit.drug_concentration': None})124        self.pd_problem.fix_parameters(name_value_dict)125        self.assertEqual(self.pd_problem.get_n_parameters(), 7)126        param_names = self.pd_problem.get_parameter_names()127        self.assertEqual(len(param_names), 7)128        self.assertEqual(param_names[0], 'myokit.tumour_volume')129        self.assertEqual(param_names[1], 'myokit.drug_concentration')130 
       self.assertEqual(param_names[2], 'myokit.kappa')131        self.assertEqual(param_names[3], 'myokit.lambda_0')132        self.assertEqual(param_names[4], 'myokit.lambda_1')133        self.assertEqual(param_names[5], 'Sigma base')134        self.assertEqual(param_names[6], 'Sigma rel.')135        # Fix parameters before setting a population model136        problem = copy.copy(self.pd_problem)137        name_value_dict = dict({138            'myokit.tumour_volume': 1,139            'myokit.drug_concentration': 0,140            'myokit.kappa': 1,141            'myokit.lambda_1': 2})142        problem.fix_parameters(name_value_dict)143        problem.set_population_model(144            pop_models=[145                chi.HeterogeneousModel(),146                chi.PooledModel(),147                chi.LogNormalModel()])148        problem.set_data(149            self.data,150            output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})151        n_ids = 3152        self.assertEqual(problem.get_n_parameters(), 2 * n_ids + 1 + 2)153        param_names = problem.get_parameter_names()154        self.assertEqual(len(param_names), 9)155        self.assertEqual(param_names[0], 'ID 0: myokit.lambda_0')156        self.assertEqual(param_names[1], 'ID 1: myokit.lambda_0')157        self.assertEqual(param_names[2], 'ID 2: myokit.lambda_0')158        self.assertEqual(param_names[3], 'Pooled Sigma base')159        self.assertEqual(param_names[4], 'ID 0: Sigma rel.')160        self.assertEqual(param_names[5], 'ID 1: Sigma rel.')161        self.assertEqual(param_names[6], 'ID 2: Sigma rel.')162        self.assertEqual(param_names[7], 'Mean log Sigma rel.')163        self.assertEqual(param_names[8], 'Std. 
log Sigma rel.')164        # Fix parameters after setting a population model165        # (Only population models can be fixed)166        name_value_dict = dict({167            'ID 1: myokit.lambda_0': 1,168            'ID 2: myokit.lambda_0': 4,169            'Pooled Sigma base': 2})170        problem.fix_parameters(name_value_dict)171        # self.assertEqual(problem.get_n_parameters(), 8)172        param_names = problem.get_parameter_names()173        self.assertEqual(len(param_names), 8)174        self.assertEqual(param_names[0], 'ID 0: myokit.lambda_0')175        self.assertEqual(param_names[1], 'ID 1: myokit.lambda_0')176        self.assertEqual(param_names[2], 'ID 2: myokit.lambda_0')177        self.assertEqual(param_names[3], 'ID 0: Sigma rel.')178        self.assertEqual(param_names[4], 'ID 1: Sigma rel.')179        self.assertEqual(param_names[5], 'ID 2: Sigma rel.')180        self.assertEqual(param_names[6], 'Mean log Sigma rel.')181        self.assertEqual(param_names[7], 'Std. 
log Sigma rel.')182        # Test case II: PKPD model183        # Fix model parameters184        name_value_dict = dict({185            'myokit.kappa': 0,186            'central.drug_concentration Sigma base': 1})187        self.pkpd_problem.fix_parameters(name_value_dict)188        self.assertEqual(self.pkpd_problem.get_n_parameters(), 9)189        param_names = self.pkpd_problem.get_parameter_names()190        self.assertEqual(len(param_names), 9)191        self.assertEqual(param_names[0], 'central.drug_amount')192        self.assertEqual(param_names[1], 'myokit.tumour_volume')193        self.assertEqual(param_names[2], 'central.size')194        self.assertEqual(param_names[3], 'myokit.critical_volume')195        self.assertEqual(param_names[4], 'myokit.elimination_rate')196        self.assertEqual(param_names[5], 'myokit.lambda')197        self.assertEqual(198            param_names[6], 'central.drug_concentration Sigma rel.')199        self.assertEqual(param_names[7], 'myokit.tumour_volume Sigma base')200        self.assertEqual(param_names[8], 'myokit.tumour_volume Sigma rel.')201        # Free and fix a parameter202        name_value_dict = dict({203            'myokit.lambda': 2,204            'myokit.kappa': None})205        self.pkpd_problem.fix_parameters(name_value_dict)206        self.assertEqual(self.pkpd_problem.get_n_parameters(), 9)207        param_names = self.pkpd_problem.get_parameter_names()208        self.assertEqual(len(param_names), 9)209        self.assertEqual(param_names[0], 'central.drug_amount')210        self.assertEqual(param_names[1], 'myokit.tumour_volume')211        self.assertEqual(param_names[2], 'central.size')212        self.assertEqual(param_names[3], 'myokit.critical_volume')213        self.assertEqual(param_names[4], 'myokit.elimination_rate')214        self.assertEqual(param_names[5], 'myokit.kappa')215        self.assertEqual(216            param_names[6], 'central.drug_concentration Sigma rel.')217        
self.assertEqual(param_names[7], 'myokit.tumour_volume Sigma base')218        self.assertEqual(param_names[8], 'myokit.tumour_volume Sigma rel.')219        # Free all parameters again220        name_value_dict = dict({221            'myokit.lambda': None,222            'central.drug_concentration Sigma base': None})223        self.pkpd_problem.fix_parameters(name_value_dict)224        self.assertEqual(self.pkpd_problem.get_n_parameters(), 11)225        param_names = self.pkpd_problem.get_parameter_names()226        self.assertEqual(len(param_names), 11)227        self.assertEqual(param_names[0], 'central.drug_amount')228        self.assertEqual(param_names[1], 'myokit.tumour_volume')229        self.assertEqual(param_names[2], 'central.size')230        self.assertEqual(param_names[3], 'myokit.critical_volume')231        self.assertEqual(param_names[4], 'myokit.elimination_rate')232        self.assertEqual(param_names[5], 'myokit.kappa')233        self.assertEqual(param_names[6], 'myokit.lambda')234        self.assertEqual(235            param_names[7], 'central.drug_concentration Sigma base')236        self.assertEqual(237            param_names[8], 'central.drug_concentration Sigma rel.')238        self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')239        self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')240    def test_fix_parameters_bad_input(self):241        # Input is not a dictionary242        name_value_dict = 'Bad type'243        with self.assertRaisesRegex(ValueError, 'The name-value dictionary'):244            self.pd_problem.fix_parameters(name_value_dict)245    def test_get_covariate_names(self):246        # Test case I: PD model247        problem = copy.deepcopy(self.pd_problem)248        # I.1: No population model249        names = problem.get_covariate_names()250        self.assertEqual(len(names), 0)251        # I.2: Population model but no covariate population model252        pop_models = [chi.PooledModel()] * 
7253        problem.set_population_model(pop_models)254        names = problem.get_covariate_names()255        self.assertEqual(len(names), 0)256        names = problem.get_covariate_names(unique=False)257        self.assertEqual(len(names), 7)258        self.assertEqual(names[0], [])259        self.assertEqual(names[1], [])260        self.assertEqual(names[2], [])261        self.assertEqual(names[3], [])262        self.assertEqual(names[3], [])263        self.assertEqual(names[3], [])264        self.assertEqual(names[3], [])265        # I.3: With covariate models266        cov_pop_model1 = chi.CovariatePopulationModel(267            chi.GaussianModel(),268            chi.LogNormalLinearCovariateModel(n_covariates=2)269        )270        cov_pop_model1.set_covariate_names(['Age', 'Sex'])271        cov_pop_model2 = chi.CovariatePopulationModel(272            chi.GaussianModel(),273            chi.LogNormalLinearCovariateModel(n_covariates=3)274        )275        cov_pop_model2.set_covariate_names(['SNP', 'Age', 'Height'])276        pop_models = [277            chi.PooledModel(),278            cov_pop_model1,279            chi.PooledModel(),280            cov_pop_model2,281            cov_pop_model1,282            chi.PooledModel(),283            chi.PooledModel()284        ]285        problem.set_population_model(pop_models)286        names = problem.get_covariate_names()287        self.assertEqual(len(names), 4)288        self.assertEqual(names[0], 'Age')289        self.assertEqual(names[1], 'Sex')290        self.assertEqual(names[2], 'SNP')291        self.assertEqual(names[3], 'Height')292        names = problem.get_covariate_names(unique=False)293        self.assertEqual(len(names), 7)294        self.assertEqual(names[0], [])295        self.assertEqual(names[1], ['Age', 'Sex'])296        self.assertEqual(names[2], [])297        self.assertEqual(names[3], ['SNP', 'Age', 'Height'])298        self.assertEqual(names[4], ['Age', 'Sex'])299        
self.assertEqual(names[5], [])300        self.assertEqual(names[6], [])301    def test_get_dosing_regimens(self):302        # Test case I: PD problem303        problem = copy.deepcopy(self.pd_problem)304        # No data has been set305        regimens = problem.get_dosing_regimens()306        self.assertIsNone(regimens)307        # Set data, but because PD model, no dosing regimen can be set308        problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})309        regimens = problem.get_dosing_regimens()310        self.assertIsNone(regimens)311        # Test case II: PKPD problem312        problem = copy.deepcopy(self.pkpd_problem)313        # No data has been set314        regimens = problem.get_dosing_regimens()315        self.assertIsNone(regimens)316        # Data has been set, but duration is ignored317        problem.set_data(318            self.data,319            output_observable_dict={320                'myokit.tumour_volume': 'Tumour volume',321                'central.drug_concentration': 'IL 6'},322            dose_duration_key=None)323        regimens = problem.get_dosing_regimens()324        self.assertIsInstance(regimens, dict)325        # Data has been set with duration information326        problem.set_data(327            self.data,328            output_observable_dict={329                'myokit.tumour_volume': 'Tumour volume',330                'central.drug_concentration': 'IL 6'})331        regimens = problem.get_dosing_regimens()332        self.assertIsInstance(regimens, dict)333    def test_get_log_prior(self):334        # Log-prior is extensively tested with get_log_posterior335        # method336        self.assertIsNone(self.pd_problem.get_log_prior())337    def test_get_log_posterior(self):338        # Test case I: Create posterior with no fixed parameters339        problem = copy.deepcopy(self.pd_problem)340        # Set data which does not provide measurements for all IDs341        problem.set_data(342            
self.data,343            output_observable_dict={'myokit.tumour_volume': 'IL 6'})344        problem.set_log_prior([345            pints.HalfCauchyLogPrior(0, 1)]*7)346        # Get all posteriors347        posteriors = problem.get_log_posterior()348        self.assertEqual(len(posteriors), 2)349        self.assertEqual(posteriors[0].n_parameters(), 7)350        self.assertEqual(posteriors[0].get_id(), 'ID 0')351        self.assertEqual(posteriors[1].n_parameters(), 7)352        self.assertEqual(posteriors[1].get_id(), 'ID 1')353        # Set data that has measurements for all IDs354        problem.set_data(355            self.data,356            output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})357        problem.set_log_prior([358            pints.HalfCauchyLogPrior(0, 1)]*7)359        # Get all posteriors360        posteriors = problem.get_log_posterior()361        self.assertEqual(len(posteriors), 3)362        self.assertEqual(posteriors[0].n_parameters(), 7)363        self.assertEqual(posteriors[0].get_id(), 'ID 0')364        self.assertEqual(posteriors[1].n_parameters(), 7)365        self.assertEqual(posteriors[1].get_id(), 'ID 1')366        self.assertEqual(posteriors[2].n_parameters(), 7)367        self.assertEqual(posteriors[2].get_id(), 'ID 2')368        # Get only one posterior369        posterior = problem.get_log_posterior(individual='0')370        self.assertIsInstance(posterior, chi.LogPosterior)371        self.assertEqual(posterior.n_parameters(), 7)372        self.assertEqual(posterior.get_id(), 'ID 0')373        # Test case II: Fix some parameters374        name_value_dict = dict({375            'myokit.drug_concentration': 0,376            'myokit.kappa': 1})377        problem.fix_parameters(name_value_dict)378        problem.set_log_prior([379            pints.HalfCauchyLogPrior(0, 1)]*5)380        # Get all posteriors381        posteriors = problem.get_log_posterior()382        self.assertEqual(len(posteriors), 3)383        
self.assertEqual(posteriors[0].n_parameters(), 5)384        self.assertEqual(posteriors[0].get_id(), 'ID 0')385        self.assertEqual(posteriors[1].n_parameters(), 5)386        self.assertEqual(posteriors[1].get_id(), 'ID 1')387        self.assertEqual(posteriors[2].n_parameters(), 5)388        self.assertEqual(posteriors[2].get_id(), 'ID 2')389        # Get only one posterior390        posterior = problem.get_log_posterior(individual='1')391        self.assertIsInstance(posterior, chi.LogPosterior)392        self.assertEqual(posterior.n_parameters(), 5)393        self.assertEqual(posterior.get_id(), 'ID 1')394        # Set a population model395        cov_pop_model = chi.CovariatePopulationModel(396            chi.GaussianModel(),397            chi.LogNormalLinearCovariateModel(n_covariates=1)398        )399        cov_pop_model.set_covariate_names(['Age'], True)400        pop_models = [401            chi.PooledModel(),402            chi.HeterogeneousModel(),403            chi.PooledModel(),404            chi.PooledModel(),405            cov_pop_model]406        problem.set_population_model(pop_models)407        problem.set_log_prior([408            pints.HalfCauchyLogPrior(0, 1)]*9)409        posterior = problem.get_log_posterior()410        self.assertIsInstance(posterior, chi.HierarchicalLogPosterior)411        self.assertEqual(posterior.n_parameters(), 12)412        names = posterior.get_parameter_names()413        ids = posterior.get_id()414        self.assertEqual(len(names), 12)415        self.assertEqual(len(ids), 12)416        self.assertEqual(names[0], 'Pooled myokit.tumour_volume')417        self.assertIsNone(ids[0])418        self.assertEqual(names[1], 'myokit.lambda_0')419        self.assertEqual(ids[1], 'ID 0')420        self.assertEqual(names[2], 'myokit.lambda_0')421        self.assertEqual(ids[2], 'ID 1')422        self.assertEqual(names[3], 'myokit.lambda_0')423        self.assertEqual(ids[3], 'ID 2')424        self.assertEqual(names[4], 
'Pooled myokit.lambda_1')425        self.assertIsNone(ids[4])426        self.assertEqual(names[5], 'Pooled Sigma base')427        self.assertIsNone(ids[5])428        self.assertEqual(names[6], 'Sigma rel. Eta')429        self.assertEqual(ids[6], 'ID 0')430        self.assertEqual(names[7], 'Sigma rel. Eta')431        self.assertEqual(ids[7], 'ID 1')432        self.assertEqual(names[8], 'Sigma rel. Eta')433        self.assertEqual(ids[8], 'ID 2')434        self.assertEqual(names[9], 'Base mean log Sigma rel.')435        self.assertIsNone(ids[9])436        self.assertEqual(names[10], 'Std. log Sigma rel.')437        self.assertIsNone(ids[10])438        self.assertEqual(names[11], 'Shift Age Sigma rel.')439        self.assertIsNone(ids[11])440        # Make sure that selecting an individual is ignored for population441        # models442        posterior = problem.get_log_posterior(individual='some individual')443        self.assertIsInstance(posterior, chi.HierarchicalLogPosterior)444        self.assertEqual(posterior.n_parameters(), 12)445        names = posterior.get_parameter_names()446        ids = posterior.get_id()447        self.assertEqual(len(names), 12)448        self.assertEqual(len(ids), 12)449        self.assertEqual(names[0], 'Pooled myokit.tumour_volume')450        self.assertIsNone(ids[0])451        self.assertEqual(names[1], 'myokit.lambda_0')452        self.assertEqual(ids[1], 'ID 0')453        self.assertEqual(names[2], 'myokit.lambda_0')454        self.assertEqual(ids[2], 'ID 1')455        self.assertEqual(names[3], 'myokit.lambda_0')456        self.assertEqual(ids[3], 'ID 2')457        self.assertEqual(names[4], 'Pooled myokit.lambda_1')458        self.assertIsNone(ids[4])459        self.assertEqual(names[5], 'Pooled Sigma base')460        self.assertIsNone(ids[5])461        self.assertEqual(names[6], 'Sigma rel. Eta')462        self.assertEqual(ids[6], 'ID 0')463        self.assertEqual(names[7], 'Sigma rel. 
Eta')464        self.assertEqual(ids[7], 'ID 1')465        self.assertEqual(names[8], 'Sigma rel. Eta')466        self.assertEqual(ids[8], 'ID 2')467        self.assertEqual(names[9], 'Base mean log Sigma rel.')468        self.assertIsNone(ids[9])469        self.assertEqual(names[10], 'Std. log Sigma rel.')470        self.assertIsNone(ids[10])471        self.assertEqual(names[11], 'Shift Age Sigma rel.')472        self.assertIsNone(ids[11])473    def test_get_log_posteriors_bad_input(self):474        problem = copy.deepcopy(self.pd_problem)475        # No log-prior has been set476        problem.set_data(477            self.data,478            output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})479        with self.assertRaisesRegex(ValueError, 'The log-prior has not'):480            problem.get_log_posterior()481        # The selected individual does not exist482        individual = 'Not existent'483        problem.set_log_prior([pints.HalfCauchyLogPrior(0, 1)]*7)484        with self.assertRaisesRegex(ValueError, 'The individual cannot'):485            problem.get_log_posterior(individual)486    def test_get_n_parameters(self):487        # Test case I: PD model488        # Test case I.1: No population model489        # Test default flag490        problem = copy.deepcopy(self.pd_problem)491        n_parameters = problem.get_n_parameters()492        self.assertEqual(n_parameters, 7)493        # Test exclude population model True494        n_parameters = problem.get_n_parameters(exclude_pop_model=True)495        self.assertEqual(n_parameters, 7)496        # Test exclude bottom-level model True497        n_parameters = problem.get_n_parameters(exclude_bottom_level=True)498        self.assertEqual(n_parameters, 7)499        # Test case I.2: Population model500        pop_models = [501            chi.PooledModel(),502            chi.PooledModel(),503            chi.HeterogeneousModel(),504            chi.PooledModel(),505            chi.PooledModel(),506    
        chi.LogNormalModel(),507            chi.LogNormalModel()]508        problem.set_population_model(pop_models)509        n_parameters = problem.get_n_parameters()510        self.assertEqual(n_parameters, 8)511        # Test exclude population model True512        n_parameters = problem.get_n_parameters(exclude_pop_model=True)513        self.assertEqual(n_parameters, 7)514        # Test exclude bottom-level model True515        n_parameters = problem.get_n_parameters(exclude_bottom_level=True)516        self.assertEqual(n_parameters, 8)517        # Test case I.3: Set data518        problem.set_data(519            self.data,520            output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})521        n_parameters = problem.get_n_parameters()522        self.assertEqual(n_parameters, 17)523        # Test exclude population model True524        n_parameters = problem.get_n_parameters(exclude_pop_model=True)525        self.assertEqual(n_parameters, 7)526        # Test exclude bottom-level model True527        n_parameters = problem.get_n_parameters(exclude_bottom_level=True)528        self.assertEqual(n_parameters, 11)529        # Test case II: PKPD model530        # Test case II.1: No population model531        # Test default flag532        problem = copy.deepcopy(self.pkpd_problem)533        n_parameters = problem.get_n_parameters()534        self.assertEqual(n_parameters, 11)535        # Test exclude population model True536        n_parameters = problem.get_n_parameters(exclude_pop_model=True)537        self.assertEqual(n_parameters, 11)538        # Test exclude bottom-level model True539        n_parameters = problem.get_n_parameters(exclude_bottom_level=True)540        self.assertEqual(n_parameters, 11)541        # Test case II.2: Population model542        pop_models = [543            chi.PooledModel(),544            chi.PooledModel(),545            chi.HeterogeneousModel(),546            chi.PooledModel(),547            chi.PooledModel(),548    
        chi.LogNormalModel(),549            chi.LogNormalModel(),550            chi.PooledModel(),551            chi.PooledModel(),552            chi.PooledModel(),553            chi.PooledModel()]554        problem.set_population_model(pop_models)555        n_parameters = problem.get_n_parameters()556        self.assertEqual(n_parameters, 12)557        # Test exclude population model True558        n_parameters = problem.get_n_parameters(exclude_pop_model=True)559        self.assertEqual(n_parameters, 11)560        # Test exclude bottom-level model True561        n_parameters = problem.get_n_parameters(exclude_bottom_level=True)562        self.assertEqual(n_parameters, 12)563        # Test case II.3: Set data564        problem.set_data(565            self.data,566            output_observable_dict={567                'myokit.tumour_volume': 'Tumour volume',568                'central.drug_concentration': 'IL 6'})569        n_parameters = problem.get_n_parameters()570        self.assertEqual(n_parameters, 21)571        # Test exclude population model True572        n_parameters = problem.get_n_parameters(exclude_pop_model=True)573        self.assertEqual(n_parameters, 11)574        # Test exclude bottom-level model True575        n_parameters = problem.get_n_parameters(exclude_bottom_level=True)576        self.assertEqual(n_parameters, 15)577    def test_get_parameter_names(self):578        # Test case I: PD model579        problem = copy.deepcopy(self.pd_problem)580        # Test case I.1: No population model581        # Test default flag582        param_names = problem.get_parameter_names()583        self.assertEqual(len(param_names), 7)584        self.assertEqual(param_names[0], 'myokit.tumour_volume')585        self.assertEqual(param_names[1], 'myokit.drug_concentration')586        self.assertEqual(param_names[2], 'myokit.kappa')587        self.assertEqual(param_names[3], 'myokit.lambda_0')588        self.assertEqual(param_names[4], 'myokit.lambda_1')589        
self.assertEqual(param_names[5], 'Sigma base')590        self.assertEqual(param_names[6], 'Sigma rel.')591        # Check that also works with exclude pop params flag592        param_names = problem.get_parameter_names(exclude_pop_model=True)593        self.assertEqual(len(param_names), 7)594        self.assertEqual(param_names[0], 'myokit.tumour_volume')595        self.assertEqual(param_names[1], 'myokit.drug_concentration')596        self.assertEqual(param_names[2], 'myokit.kappa')597        self.assertEqual(param_names[3], 'myokit.lambda_0')598        self.assertEqual(param_names[4], 'myokit.lambda_1')599        self.assertEqual(param_names[5], 'Sigma base')600        self.assertEqual(param_names[6], 'Sigma rel.')601        # Check that also works with exclude bottom-level flag602        param_names = problem.get_parameter_names(exclude_bottom_level=True)603        self.assertEqual(len(param_names), 7)604        self.assertEqual(param_names[0], 'myokit.tumour_volume')605        self.assertEqual(param_names[1], 'myokit.drug_concentration')606        self.assertEqual(param_names[2], 'myokit.kappa')607        self.assertEqual(param_names[3], 'myokit.lambda_0')608        self.assertEqual(param_names[4], 'myokit.lambda_1')609        self.assertEqual(param_names[5], 'Sigma base')610        self.assertEqual(param_names[6], 'Sigma rel.')611        # Test case I.2: Population model612        cov_population_model = chi.CovariatePopulationModel(613            chi.GaussianModel(), chi.LogNormalLinearCovariateModel())614        pop_models = [615            chi.PooledModel(),616            chi.PooledModel(),617            chi.HeterogeneousModel(),618            chi.PooledModel(),619            chi.PooledModel(),620            cov_population_model,621            chi.LogNormalModel()]622        problem.set_population_model(pop_models)623        param_names = problem.get_parameter_names()624        self.assertEqual(len(param_names), 8)625        self.assertEqual(param_names[0], 
'Pooled myokit.tumour_volume')626        self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')627        self.assertEqual(param_names[2], 'Pooled myokit.lambda_0')628        self.assertEqual(param_names[3], 'Pooled myokit.lambda_1')629        self.assertEqual(param_names[4], 'Base mean log Sigma base')630        self.assertEqual(param_names[5], 'Std. log Sigma base')631        self.assertEqual(param_names[6], 'Mean log Sigma rel.')632        self.assertEqual(param_names[7], 'Std. log Sigma rel.')633        # Test exclude population model True634        param_names = problem.get_parameter_names(exclude_pop_model=True)635        self.assertEqual(len(param_names), 7)636        self.assertEqual(param_names[0], 'myokit.tumour_volume')637        self.assertEqual(param_names[1], 'myokit.drug_concentration')638        self.assertEqual(param_names[2], 'myokit.kappa')639        self.assertEqual(param_names[3], 'myokit.lambda_0')640        self.assertEqual(param_names[4], 'myokit.lambda_1')641        self.assertEqual(param_names[5], 'Sigma base')642        self.assertEqual(param_names[6], 'Sigma rel.')643        # Test exclude bottom-level True644        param_names = problem.get_parameter_names(exclude_bottom_level=True)645        self.assertEqual(len(param_names), 8)646        self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')647        self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')648        self.assertEqual(param_names[2], 'Pooled myokit.lambda_0')649        self.assertEqual(param_names[3], 'Pooled myokit.lambda_1')650        self.assertEqual(param_names[4], 'Base mean log Sigma base')651        self.assertEqual(param_names[5], 'Std. log Sigma base')652        self.assertEqual(param_names[6], 'Mean log Sigma rel.')653        self.assertEqual(param_names[7], 'Std. 
log Sigma rel.')654        # Test case I.3: Set data655        problem.set_data(656            self.data,657            output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})658        param_names = problem.get_parameter_names()659        self.assertEqual(len(param_names), 17)660        self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')661        self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')662        self.assertEqual(param_names[2], 'ID 0: myokit.kappa')663        self.assertEqual(param_names[3], 'ID 1: myokit.kappa')664        self.assertEqual(param_names[4], 'ID 2: myokit.kappa')665        self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')666        self.assertEqual(param_names[6], 'Pooled myokit.lambda_1')667        self.assertEqual(param_names[7], 'ID 0: Sigma base Eta')668        self.assertEqual(param_names[8], 'ID 1: Sigma base Eta')669        self.assertEqual(param_names[9], 'ID 2: Sigma base Eta')670        self.assertEqual(param_names[10], 'Base mean log Sigma base')671        self.assertEqual(param_names[11], 'Std. log Sigma base')672        self.assertEqual(param_names[12], 'ID 0: Sigma rel.')673        self.assertEqual(param_names[13], 'ID 1: Sigma rel.')674        self.assertEqual(param_names[14], 'ID 2: Sigma rel.')675        self.assertEqual(param_names[15], 'Mean log Sigma rel.')676        self.assertEqual(param_names[16], 'Std. 
log Sigma rel.')677        # Test exclude population model True678        param_names = problem.get_parameter_names(exclude_pop_model=True)679        self.assertEqual(len(param_names), 7)680        self.assertEqual(param_names[0], 'myokit.tumour_volume')681        self.assertEqual(param_names[1], 'myokit.drug_concentration')682        self.assertEqual(param_names[2], 'myokit.kappa')683        self.assertEqual(param_names[3], 'myokit.lambda_0')684        self.assertEqual(param_names[4], 'myokit.lambda_1')685        self.assertEqual(param_names[5], 'Sigma base')686        self.assertEqual(param_names[6], 'Sigma rel.')687        # Test exclude bottom-level True688        param_names = problem.get_parameter_names(exclude_bottom_level=True)689        self.assertEqual(len(param_names), 11)690        self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')691        self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')692        self.assertEqual(param_names[2], 'ID 0: myokit.kappa')693        self.assertEqual(param_names[3], 'ID 1: myokit.kappa')694        self.assertEqual(param_names[4], 'ID 2: myokit.kappa')695        self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')696        self.assertEqual(param_names[6], 'Pooled myokit.lambda_1')697        self.assertEqual(param_names[7], 'Base mean log Sigma base')698        self.assertEqual(param_names[8], 'Std. log Sigma base')699        self.assertEqual(param_names[9], 'Mean log Sigma rel.')700        self.assertEqual(param_names[10], 'Std. 
log Sigma rel.')701        # Test case II: PKPD model702        problem = copy.deepcopy(self.pkpd_problem)703        # Test case II.1: No population model704        # Test default flag705        param_names = problem.get_parameter_names()706        self.assertEqual(len(param_names), 11)707        self.assertEqual(param_names[0], 'central.drug_amount')708        self.assertEqual(param_names[1], 'myokit.tumour_volume')709        self.assertEqual(param_names[2], 'central.size')710        self.assertEqual(param_names[3], 'myokit.critical_volume')711        self.assertEqual(param_names[4], 'myokit.elimination_rate')712        self.assertEqual(param_names[5], 'myokit.kappa')713        self.assertEqual(param_names[6], 'myokit.lambda')714        self.assertEqual(715            param_names[7], 'central.drug_concentration Sigma base')716        self.assertEqual(717            param_names[8], 'central.drug_concentration Sigma rel.')718        self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')719        self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')720        # Test exclude population model True721        param_names = problem.get_parameter_names(exclude_pop_model=True)722        self.assertEqual(len(param_names), 11)723        self.assertEqual(param_names[0], 'central.drug_amount')724        self.assertEqual(param_names[1], 'myokit.tumour_volume')725        self.assertEqual(param_names[2], 'central.size')726        self.assertEqual(param_names[3], 'myokit.critical_volume')727        self.assertEqual(param_names[4], 'myokit.elimination_rate')728        self.assertEqual(param_names[5], 'myokit.kappa')729        self.assertEqual(param_names[6], 'myokit.lambda')730        self.assertEqual(731            param_names[7], 'central.drug_concentration Sigma base')732        self.assertEqual(733            param_names[8], 'central.drug_concentration Sigma rel.')734        self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')735       
 self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')736        # Test exclude population model True737        param_names = problem.get_parameter_names(exclude_bottom_level=True)738        self.assertEqual(len(param_names), 11)739        self.assertEqual(param_names[0], 'central.drug_amount')740        self.assertEqual(param_names[1], 'myokit.tumour_volume')741        self.assertEqual(param_names[2], 'central.size')742        self.assertEqual(param_names[3], 'myokit.critical_volume')743        self.assertEqual(param_names[4], 'myokit.elimination_rate')744        self.assertEqual(param_names[5], 'myokit.kappa')745        self.assertEqual(param_names[6], 'myokit.lambda')746        self.assertEqual(747            param_names[7], 'central.drug_concentration Sigma base')748        self.assertEqual(749            param_names[8], 'central.drug_concentration Sigma rel.')750        self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')751        self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')752        # Test case II.2: Population model753        cov_pop_model = chi.CovariatePopulationModel(754            chi.GaussianModel(),755            chi.LogNormalLinearCovariateModel(n_covariates=1)756        )757        cov_pop_model.set_covariate_names(['Age'], True)758        pop_models = [759            chi.PooledModel(),760            chi.PooledModel(),761            chi.HeterogeneousModel(),762            chi.PooledModel(),763            chi.PooledModel(),764            chi.LogNormalModel(),765            chi.LogNormalModel(),766            chi.PooledModel(),767            cov_pop_model,768            chi.PooledModel(),769            chi.PooledModel()]770        problem.set_population_model(pop_models)771        param_names = problem.get_parameter_names()772        self.assertEqual(len(param_names), 14)773        self.assertEqual(param_names[0], 'Pooled central.drug_amount')774        self.assertEqual(param_names[1], 'Pooled 
myokit.tumour_volume')775        self.assertEqual(param_names[2], 'Pooled myokit.critical_volume')776        self.assertEqual(param_names[3], 'Pooled myokit.elimination_rate')777        self.assertEqual(param_names[4], 'Mean log myokit.kappa')778        self.assertEqual(param_names[5], 'Std. log myokit.kappa')779        self.assertEqual(param_names[6], 'Mean log myokit.lambda')780        self.assertEqual(param_names[7], 'Std. log myokit.lambda')781        self.assertEqual(782            param_names[8], 'Pooled central.drug_concentration Sigma base')783        self.assertEqual(784            param_names[9],785            'Base mean log central.drug_concentration Sigma rel.')786        self.assertEqual(787            param_names[10], 'Std. log central.drug_concentration Sigma rel.')788        self.assertEqual(789            param_names[11], 'Shift Age central.drug_concentration Sigma rel.')790        self.assertEqual(791            param_names[12], 'Pooled myokit.tumour_volume Sigma base')792        self.assertEqual(793            param_names[13], 'Pooled myokit.tumour_volume Sigma rel.')794        # Test exclude population model True795        param_names = problem.get_parameter_names(exclude_pop_model=True)796        self.assertEqual(len(param_names), 11)797        self.assertEqual(param_names[0], 'central.drug_amount')798        self.assertEqual(param_names[1], 'myokit.tumour_volume')799        self.assertEqual(param_names[2], 'central.size')800        self.assertEqual(param_names[3], 'myokit.critical_volume')801        self.assertEqual(param_names[4], 'myokit.elimination_rate')802        self.assertEqual(param_names[5], 'myokit.kappa')803        self.assertEqual(param_names[6], 'myokit.lambda')804        self.assertEqual(805            param_names[7], 'central.drug_concentration Sigma base')806        self.assertEqual(807            param_names[8], 'central.drug_concentration Sigma rel.')808        self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma 
base')809        self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')810        # Test exclude bottom-level True811        param_names = problem.get_parameter_names(exclude_bottom_level=True)812        self.assertEqual(len(param_names), 14)813        self.assertEqual(param_names[0], 'Pooled central.drug_amount')814        self.assertEqual(param_names[1], 'Pooled myokit.tumour_volume')815        self.assertEqual(param_names[2], 'Pooled myokit.critical_volume')816        self.assertEqual(param_names[3], 'Pooled myokit.elimination_rate')817        self.assertEqual(param_names[4], 'Mean log myokit.kappa')818        self.assertEqual(param_names[5], 'Std. log myokit.kappa')819        self.assertEqual(param_names[6], 'Mean log myokit.lambda')820        self.assertEqual(param_names[7], 'Std. log myokit.lambda')821        self.assertEqual(822            param_names[8], 'Pooled central.drug_concentration Sigma base')823        self.assertEqual(824            param_names[9],825            'Base mean log central.drug_concentration Sigma rel.')826        self.assertEqual(827            param_names[10], 'Std. 
log central.drug_concentration Sigma rel.')828        self.assertEqual(829            param_names[11], 'Shift Age central.drug_concentration Sigma rel.')830        self.assertEqual(831            param_names[12], 'Pooled myokit.tumour_volume Sigma base')832        self.assertEqual(833            param_names[13], 'Pooled myokit.tumour_volume Sigma rel.')834        # Test case II.3: Set data835        problem.set_data(836            self.data,837            output_observable_dict={838                'myokit.tumour_volume': 'Tumour volume',839                'central.drug_concentration': 'IL 6'})840        param_names = problem.get_parameter_names()841        self.assertEqual(len(param_names), 26)842        self.assertEqual(param_names[0], 'Pooled central.drug_amount')843        self.assertEqual(param_names[1], 'Pooled myokit.tumour_volume')844        self.assertEqual(param_names[2], 'ID 0: central.size')845        self.assertEqual(param_names[3], 'ID 1: central.size')846        self.assertEqual(param_names[4], 'ID 2: central.size')847        self.assertEqual(param_names[5], 'Pooled myokit.critical_volume')848        self.assertEqual(param_names[6], 'Pooled myokit.elimination_rate')849        self.assertEqual(param_names[7], 'ID 0: myokit.kappa')850        self.assertEqual(param_names[8], 'ID 1: myokit.kappa')851        self.assertEqual(param_names[9], 'ID 2: myokit.kappa')852        self.assertEqual(param_names[10], 'Mean log myokit.kappa')853        self.assertEqual(param_names[11], 'Std. log myokit.kappa')854        self.assertEqual(param_names[12], 'ID 0: myokit.lambda')855        self.assertEqual(param_names[13], 'ID 1: myokit.lambda')856        self.assertEqual(param_names[14], 'ID 2: myokit.lambda')857        self.assertEqual(param_names[15], 'Mean log myokit.lambda')858        self.assertEqual(param_names[16], 'Std. 
log myokit.lambda')859        self.assertEqual(860            param_names[17], 'Pooled central.drug_concentration Sigma base')861        self.assertEqual(862            param_names[18],863            'ID 0: central.drug_concentration Sigma rel. Eta')864        self.assertEqual(865            param_names[19],866            'ID 1: central.drug_concentration Sigma rel. Eta')867        self.assertEqual(868            param_names[20],869            'ID 2: central.drug_concentration Sigma rel. Eta')870        self.assertEqual(871            param_names[21],872            'Base mean log central.drug_concentration Sigma rel.')873        self.assertEqual(874            param_names[22], 'Std. log central.drug_concentration Sigma rel.')875        self.assertEqual(876            param_names[23], 'Shift Age central.drug_concentration Sigma rel.')877        self.assertEqual(878            param_names[24], 'Pooled myokit.tumour_volume Sigma base')879        self.assertEqual(880            param_names[25], 'Pooled myokit.tumour_volume Sigma rel.')881        # Test exclude population model True882        param_names = problem.get_parameter_names(exclude_pop_model=True)883        self.assertEqual(len(param_names), 11)884        self.assertEqual(param_names[0], 'central.drug_amount')885        self.assertEqual(param_names[1], 'myokit.tumour_volume')886        self.assertEqual(param_names[2], 'central.size')887        self.assertEqual(param_names[3], 'myokit.critical_volume')888        self.assertEqual(param_names[4], 'myokit.elimination_rate')889        self.assertEqual(param_names[5], 'myokit.kappa')890        self.assertEqual(param_names[6], 'myokit.lambda')891        self.assertEqual(892            param_names[7], 'central.drug_concentration Sigma base')893        self.assertEqual(894            param_names[8], 'central.drug_concentration Sigma rel.')895        self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')896        self.assertEqual(param_names[10], 
'myokit.tumour_volume Sigma rel.')897        # Test exclude bottom-level True898        param_names = problem.get_parameter_names(exclude_bottom_level=True)899        self.assertEqual(len(param_names), 17)900        self.assertEqual(param_names[0], 'Pooled central.drug_amount')901        self.assertEqual(param_names[1], 'Pooled myokit.tumour_volume')902        self.assertEqual(param_names[2], 'ID 0: central.size')903        self.assertEqual(param_names[3], 'ID 1: central.size')904        self.assertEqual(param_names[4], 'ID 2: central.size')905        self.assertEqual(param_names[5], 'Pooled myokit.critical_volume')906        self.assertEqual(param_names[6], 'Pooled myokit.elimination_rate')907        self.assertEqual(param_names[7], 'Mean log myokit.kappa')908        self.assertEqual(param_names[8], 'Std. log myokit.kappa')909        self.assertEqual(param_names[9], 'Mean log myokit.lambda')910        self.assertEqual(param_names[10], 'Std. log myokit.lambda')911        self.assertEqual(912            param_names[11], 'Pooled central.drug_concentration Sigma base')913        self.assertEqual(914            param_names[12],915            'Base mean log central.drug_concentration Sigma rel.')916        self.assertEqual(917            param_names[13], 'Std. 
log central.drug_concentration Sigma rel.')918        self.assertEqual(919            param_names[14], 'Shift Age central.drug_concentration Sigma rel.')920        self.assertEqual(921            param_names[15], 'Pooled myokit.tumour_volume Sigma base')922        self.assertEqual(923            param_names[16], 'Pooled myokit.tumour_volume Sigma rel.')924    def test_get_predictive_model(self):925        # Test case I: PD model926        problem = copy.deepcopy(self.pd_problem)927        # Test case I.1: No population model928        predictive_model = problem.get_predictive_model()929        self.assertIsInstance(predictive_model, chi.PredictiveModel)930        # Exclude population model931        predictive_model = problem.get_predictive_model(932            exclude_pop_model=True)933        self.assertIsInstance(predictive_model, chi.PredictiveModel)934        # Test case I.2: Population model935        problem.set_population_model([936            chi.PooledModel(),937            chi.PooledModel(),938            chi.HeterogeneousModel(),939            chi.PooledModel(),940            chi.PooledModel(),941            chi.LogNormalModel(),942            chi.LogNormalModel()])943        predictive_model = problem.get_predictive_model()944        self.assertIsInstance(945            predictive_model, chi.PopulationPredictiveModel)946        # Exclude population model947        predictive_model = problem.get_predictive_model(948            exclude_pop_model=True)949        self.assertNotIsInstance(950            predictive_model, chi.PopulationPredictiveModel)951        self.assertIsInstance(predictive_model, chi.PredictiveModel)952        # Test case II: PKPD model953        problem = copy.deepcopy(self.pkpd_problem)954        # Test case II.1: No population model955        predictive_model = problem.get_predictive_model()956        self.assertIsInstance(predictive_model, chi.PredictiveModel)957        # Exclude population model958        predictive_model = 
problem.get_predictive_model(959            exclude_pop_model=True)960        self.assertIsInstance(predictive_model, chi.PredictiveModel)961        # Test case II.2: Population model962        problem.set_population_model([963            chi.PooledModel(),964            chi.PooledModel(),965            chi.HeterogeneousModel(),966            chi.PooledModel(),967            chi.PooledModel(),968            chi.LogNormalModel(),969            chi.LogNormalModel(),970            chi.PooledModel(),971            chi.PooledModel(),972            chi.PooledModel(),973            chi.PooledModel()])974        predictive_model = problem.get_predictive_model()975        self.assertIsInstance(976            predictive_model, chi.PopulationPredictiveModel)977        # Exclude population model978        predictive_model = problem.get_predictive_model(979            exclude_pop_model=True)980        self.assertNotIsInstance(981            predictive_model, chi.PopulationPredictiveModel)982        self.assertIsInstance(predictive_model, chi.PredictiveModel)983    def test_set_data(self):984        # Set data with explicit output-observable map985        problem = copy.deepcopy(self.pd_problem)986        output_observable_dict = {'myokit.tumour_volume': 'Tumour volume'}987        problem.set_data(self.data, output_observable_dict)988        # Set data with implicit output-observable map989        mask = self.data['Observable'] == 'Tumour volume'990        problem.set_data(self.data[mask])991        # Set data with explicit covariate mapping992        cov_pop_model = chi.CovariatePopulationModel(993            chi.GaussianModel(),994            chi.LogNormalLinearCovariateModel(n_covariates=1)995        )996        cov_pop_model.set_covariate_names(['Sex'], True)997        pop_models = [cov_pop_model] * 7998        problem.set_population_model(pop_models)999        covariate_dict = {'Sex': 'Age'}1000        problem.set_data(self.data, output_observable_dict, covariate_dict)1001  
  def test_set_data_bad_input(self):1002        # Data has the wrong type1003        data = 'Wrong type'1004        with self.assertRaisesRegex(TypeError, 'Data has to be a'):1005            self.pd_problem.set_data(data)1006        # Data has the wrong ID key1007        data = self.data.rename(columns={'ID': 'Some key'})1008        with self.assertRaisesRegex(ValueError, 'Data does not have the'):1009            self.pkpd_problem.set_data(data)1010        # Data has the wrong time key1011        data = self.data.rename(columns={'Time': 'Some key'})1012        with self.assertRaisesRegex(ValueError, 'Data does not have the'):1013            self.pkpd_problem.set_data(data)1014        # Data has the wrong observable key1015        data = self.data.rename(columns={'Observable': 'Some key'})1016        with self.assertRaisesRegex(ValueError, 'Data does not have the'):1017            self.pkpd_problem.set_data(data)1018        # Data has the wrong value key1019        data = self.data.rename(columns={'Value': 'Some key'})1020        with self.assertRaisesRegex(ValueError, 'Data does not have the'):1021            self.pkpd_problem.set_data(data)1022        # Data has the wrong dose key1023        data = self.data.rename(columns={'Dose': 'Some key'})1024        with self.assertRaisesRegex(ValueError, 'Data does not have the'):1025            self.pkpd_problem.set_data(data)1026        # Data has the wrong duration key1027        data = self.data.rename(columns={'Duration': 'Some key'})1028        with self.assertRaisesRegex(ValueError, 'Data does not have the'):1029            self.pkpd_problem.set_data(data)1030        # The output-observable map does not contain a model output1031        output_observable_dict = {'some output': 'some observable'}1032        with self.assertRaisesRegex(ValueError, 'The output <central.drug'):1033            self.pkpd_problem.set_data(self.data, output_observable_dict)1034        # The output-observable map references a observable that 
is not in the1035        # dataframe1036        output_observable_dict = {'myokit.tumour_volume': 'some observable'}1037        with self.assertRaisesRegex(ValueError, 'The observable <some'):1038            self.pd_problem.set_data(self.data, output_observable_dict)1039        # The model outputs and dataframe observable cannot be trivially mapped1040        with self.assertRaisesRegex(ValueError, 'The observable <central.'):1041            self.pkpd_problem.set_data(self.data)1042        # Covariate map does not contain all model covariates1043        problem = copy.deepcopy(self.pd_problem)1044        cov_pop_model1 = chi.CovariatePopulationModel(1045            chi.GaussianModel(),1046            chi.LogNormalLinearCovariateModel(n_covariates=1)1047        )1048        cov_pop_model1.set_covariate_names(['Age'], True)1049        cov_pop_model2 = chi.CovariatePopulationModel(1050            chi.GaussianModel(),1051            chi.LogNormalLinearCovariateModel(n_covariates=1)1052        )1053        cov_pop_model2.set_covariate_names(['Sex'], True)1054        pop_models = [cov_pop_model1] * 4 + [cov_pop_model2] * 31055        problem.set_population_model(pop_models)1056        output_observable_dict = {'myokit.tumour_volume': 'Tumour volume'}1057        covariate_dict = {'Age': 'Age', 'Something': 'else'}1058        with self.assertRaisesRegex(ValueError, 'The covariate <Sex> could'):1059            problem.set_data(1060                self.data,1061                output_observable_dict=output_observable_dict,1062                covariate_dict=covariate_dict)1063        # Covariate dict maps to covariate that is not in the dataframe1064        covariate_dict = {'Age': 'Age', 'Sex': 'Does not exist'}1065        with self.assertRaisesRegex(ValueError, 'The covariate <Does not ex'):1066            problem.set_data(1067                self.data,1068                output_observable_dict=output_observable_dict,1069                covariate_dict=covariate_dict)1070    
    # There are no covariate values provided for an ID1071        data = self.data.copy()1072        mask = (data.ID == 1) | (data.Observable == 'Age')1073        data.loc[mask, 'Value'] = np.nan1074        pop_models = [cov_pop_model1] * 71075        problem.set_population_model(pop_models)1076        with self.assertRaisesRegex(ValueError, 'There are either 0 or more'):1077            problem.set_data(1078                data,1079                output_observable_dict=output_observable_dict)1080        # There is more than one covariate value provided for an ID1081        data = self.data.copy()1082        mask = data.Observable == 'Age'1083        data.loc[mask, 'ID'] = 01084        pop_models = [cov_pop_model1] * 71085        problem.set_population_model(pop_models)1086        with self.assertRaisesRegex(ValueError, 'There are either 0 or more'):1087            problem.set_data(1088                data,1089                output_observable_dict=output_observable_dict)1090    def test_set_log_prior(self):1091        # Test case I: PD model1092        problem = copy.deepcopy(self.pd_problem)1093        problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})1094        log_priors = [pints.HalfCauchyLogPrior(0, 1)] * 71095        # Map priors to parameters automatically1096        problem.set_log_prior(log_priors)1097        # Specify prior parameter map explicitly1098        param_names = [1099            'myokit.kappa',1100            'Sigma base',1101            'Sigma rel.',1102            'myokit.tumour_volume',1103            'myokit.lambda_1',1104            'myokit.drug_concentration',1105            'myokit.lambda_0']1106        problem.set_log_prior(log_priors, param_names)1107    def test_set_log_prior_bad_input(self):1108        problem = copy.deepcopy(self.pd_problem)1109        # No data has been set1110        with self.assertRaisesRegex(ValueError, 'The data has not'):1111            problem.set_log_prior('some prior')1112        # 
Wrong log-prior type1113        problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})1114        log_priors = ['Wrong', 'type']1115        with self.assertRaisesRegex(ValueError, 'All marginal log-priors'):1116            problem.set_log_prior(log_priors)1117        # Number of log priors does not match number of parameters1118        log_priors = [1119            pints.GaussianLogPrior(0, 1), pints.HalfCauchyLogPrior(0, 1)]1120        with self.assertRaisesRegex(ValueError, 'One marginal log-prior'):1121            problem.set_log_prior(log_priors)1122        # Dimensionality of joint log-pior does not match number of parameters1123        prior = pints.ComposedLogPrior(1124            pints.GaussianLogPrior(0, 1), pints.GaussianLogPrior(0, 1))1125        log_priors = [1126            prior,1127            pints.UniformLogPrior(0, 1),1128            pints.UniformLogPrior(0, 1),1129            pints.UniformLogPrior(0, 1),1130            pints.UniformLogPrior(0, 1),1131            pints.UniformLogPrior(0, 1),1132            pints.UniformLogPrior(0, 1)]1133        with self.assertRaisesRegex(ValueError, 'The joint log-prior'):1134            problem.set_log_prior(log_priors)1135        # Specified parameter names do not match the model parameters1136        params = ['wrong', 'params']1137        log_priors = [pints.HalfCauchyLogPrior(0, 1)] * 71138        with self.assertRaisesRegex(ValueError, 'The specified parameter'):1139            problem.set_log_prior(log_priors, params)1140    def test_set_population_model(self):1141        # Test case I: PD model1142        problem = copy.deepcopy(self.pd_problem)1143        problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})1144        pop_models = [1145            chi.PooledModel(),1146            chi.PooledModel(),1147            chi.HeterogeneousModel(),1148            chi.PooledModel(),1149            chi.PooledModel(),1150            chi.PooledModel(),1151            
chi.LogNormalModel()]1152        # Test case I.1: Don't specify order1153        problem.set_population_model(pop_models)1154        self.assertEqual(problem.get_n_parameters(), 13)1155        param_names = problem.get_parameter_names()1156        self.assertEqual(len(param_names), 13)1157        self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')1158        self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')1159        self.assertEqual(param_names[2], 'ID 0: myokit.kappa')1160        self.assertEqual(param_names[3], 'ID 1: myokit.kappa')1161        self.assertEqual(param_names[4], 'ID 2: myokit.kappa')1162        self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')1163        self.assertEqual(param_names[6], 'Pooled myokit.lambda_1')1164        self.assertEqual(param_names[7], 'Pooled Sigma base')1165        self.assertEqual(param_names[8], 'ID 0: Sigma rel.')1166        self.assertEqual(param_names[9], 'ID 1: Sigma rel.')1167        self.assertEqual(param_names[10], 'ID 2: Sigma rel.')1168        self.assertEqual(param_names[11], 'Mean log Sigma rel.')1169        self.assertEqual(param_names[12], 'Std. 
log Sigma rel.')1170        # Test case I.2: Specify order1171        parameter_names = [1172            'Sigma base',1173            'myokit.drug_concentration',1174            'myokit.lambda_1',1175            'myokit.kappa',1176            'myokit.tumour_volume',1177            'Sigma rel.',1178            'myokit.lambda_0']1179        problem.set_population_model(pop_models, parameter_names)1180        self.assertEqual(problem.get_n_parameters(), 13)1181        param_names = problem.get_parameter_names()1182        self.assertEqual(len(param_names), 13)1183        self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')1184        self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')1185        self.assertEqual(param_names[2], 'Pooled myokit.kappa')1186        self.assertEqual(param_names[3], 'ID 0: myokit.lambda_0')1187        self.assertEqual(param_names[4], 'ID 1: myokit.lambda_0')1188        self.assertEqual(param_names[5], 'ID 2: myokit.lambda_0')1189        self.assertEqual(param_names[6], 'Mean log myokit.lambda_0')1190        self.assertEqual(param_names[7], 'Std. 
log myokit.lambda_0')1191        self.assertEqual(param_names[8], 'ID 0: myokit.lambda_1')1192        self.assertEqual(param_names[9], 'ID 1: myokit.lambda_1')1193        self.assertEqual(param_names[10], 'ID 2: myokit.lambda_1')1194        self.assertEqual(param_names[11], 'Pooled Sigma base')1195        self.assertEqual(param_names[12], 'Pooled Sigma rel.')1196        # Test case I.3: With covariates1197        cov_pop_model = chi.CovariatePopulationModel(1198            chi.GaussianModel(),1199            chi.LogNormalLinearCovariateModel(n_covariates=1)1200        )1201        cov_pop_model.set_covariate_names(['Age'], True)1202        pop_models = [1203            chi.PooledModel(),1204            chi.PooledModel(),1205            chi.HeterogeneousModel(),1206            chi.PooledModel(),1207            cov_pop_model,1208            chi.PooledModel(),1209            chi.LogNormalModel()]1210        problem.set_population_model(pop_models)1211        self.assertEqual(problem.get_n_parameters(), 18)1212        param_names = problem.get_parameter_names()1213        self.assertEqual(len(param_names), 18)1214        self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')1215        self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')1216        self.assertEqual(param_names[2], 'ID 0: myokit.kappa')1217        self.assertEqual(param_names[3], 'ID 1: myokit.kappa')1218        self.assertEqual(param_names[4], 'ID 2: myokit.kappa')1219        self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')1220        self.assertEqual(param_names[6], 'ID 0: myokit.lambda_1 Eta')1221        self.assertEqual(param_names[7], 'ID 1: myokit.lambda_1 Eta')1222        self.assertEqual(param_names[8], 'ID 2: myokit.lambda_1 Eta')1223        self.assertEqual(param_names[9], 'Base mean log myokit.lambda_1')1224        self.assertEqual(param_names[10], 'Std. 
log myokit.lambda_1')1225        self.assertEqual(param_names[11], 'Shift Age myokit.lambda_1')1226        self.assertEqual(param_names[12], 'Pooled Sigma base')1227        self.assertEqual(param_names[13], 'ID 0: Sigma rel.')1228        self.assertEqual(param_names[14], 'ID 1: Sigma rel.')1229        self.assertEqual(param_names[15], 'ID 2: Sigma rel.')1230        self.assertEqual(param_names[16], 'Mean log Sigma rel.')1231        self.assertEqual(param_names[17], 'Std. log Sigma rel.')1232        # Test case II: PKPD model1233        problem = copy.deepcopy(self.pkpd_problem)1234        problem.set_data(1235            self.data,1236            output_observable_dict={1237                'central.drug_concentration': 'IL 6',1238                'myokit.tumour_volume': 'Tumour volume'})1239        pop_models = [1240            chi.LogNormalModel(),1241            chi.LogNormalModel(),1242            chi.LogNormalModel(),1243            chi.PooledModel(),1244            chi.PooledModel(),1245            chi.HeterogeneousModel(),1246            chi.PooledModel(),1247            chi.PooledModel(),1248            chi.LogNormalModel(),1249            chi.PooledModel(),1250            chi.LogNormalModel()]1251        # Test case I.1: Don't specify order1252        problem.set_population_model(pop_models)1253        self.assertEqual(problem.get_n_parameters(), 33)1254        param_names = problem.get_parameter_names()1255        self.assertEqual(len(param_names), 33)1256        self.assertEqual(param_names[0], 'ID 0: central.drug_amount')1257        self.assertEqual(param_names[1], 'ID 1: central.drug_amount')1258        self.assertEqual(param_names[2], 'ID 2: central.drug_amount')1259        self.assertEqual(param_names[3], 'Mean log central.drug_amount')1260        self.assertEqual(param_names[4], 'Std. 
log central.drug_amount')1261        self.assertEqual(param_names[5], 'ID 0: myokit.tumour_volume')1262        self.assertEqual(param_names[6], 'ID 1: myokit.tumour_volume')1263        self.assertEqual(param_names[7], 'ID 2: myokit.tumour_volume')1264        self.assertEqual(param_names[8], 'Mean log myokit.tumour_volume')1265        self.assertEqual(param_names[9], 'Std. log myokit.tumour_volume')1266        self.assertEqual(param_names[10], 'ID 0: central.size')1267        self.assertEqual(param_names[11], 'ID 1: central.size')1268        self.assertEqual(param_names[12], 'ID 2: central.size')1269        self.assertEqual(param_names[13], 'Mean log central.size')1270        self.assertEqual(param_names[14], 'Std. log central.size')1271        self.assertEqual(param_names[15], 'Pooled myokit.critical_volume')1272        self.assertEqual(param_names[16], 'Pooled myokit.elimination_rate')1273        self.assertEqual(param_names[17], 'ID 0: myokit.kappa')1274        self.assertEqual(param_names[18], 'ID 1: myokit.kappa')1275        self.assertEqual(param_names[19], 'ID 2: myokit.kappa')1276        self.assertEqual(param_names[20], 'Pooled myokit.lambda')1277        self.assertEqual(1278            param_names[21], 'Pooled central.drug_concentration Sigma base')1279        self.assertEqual(1280            param_names[22], 'ID 0: central.drug_concentration Sigma rel.')1281        self.assertEqual(1282            param_names[23], 'ID 1: central.drug_concentration Sigma rel.')1283        self.assertEqual(1284            param_names[24], 'ID 2: central.drug_concentration Sigma rel.')1285        self.assertEqual(1286            param_names[25], 'Mean log central.drug_concentration Sigma rel.')1287        self.assertEqual(1288            param_names[26], 'Std. 
log central.drug_concentration Sigma rel.')1289        self.assertEqual(1290            param_names[27], 'Pooled myokit.tumour_volume Sigma base')1291        self.assertEqual(1292            param_names[28], 'ID 0: myokit.tumour_volume Sigma rel.')1293        self.assertEqual(1294            param_names[29], 'ID 1: myokit.tumour_volume Sigma rel.')1295        self.assertEqual(1296            param_names[30], 'ID 2: myokit.tumour_volume Sigma rel.')1297        self.assertEqual(1298            param_names[31], 'Mean log myokit.tumour_volume Sigma rel.')1299        self.assertEqual(1300            param_names[32], 'Std. log myokit.tumour_volume Sigma rel.')1301    def test_set_population_model_bad_input(self):1302        # Population models have the wrong type1303        pop_models = ['bad', 'type']1304        with self.assertRaisesRegex(TypeError, 'The population models'):1305            self.pd_problem.set_population_model(pop_models)1306        # Number of population models is not correct1307        pop_models = [chi.PooledModel()]1308        with self.assertRaisesRegex(ValueError, 'The number of population'):1309            self.pd_problem.set_population_model(pop_models)1310        # Specified parameter names do not coincide with model1311        pop_models = [chi.PooledModel()] * 71312        parameter_names = ['wrong names'] * 71313        with self.assertRaisesRegex(ValueError, 'The parameter names'):1314            self.pd_problem.set_population_model(pop_models, parameter_names)1315        # User is warned that data is reset as a result of unclear covariate1316        # mapping1317        self.pd_problem.set_data(1318            self.data,1319            output_observable_dict={1320                'central.drug_concentration': 'IL 6',1321                'myokit.tumour_volume': 'Tumour volume'})1322        cov_pop_model = chi.CovariatePopulationModel(1323            chi.GaussianModel(),1324            chi.LogNormalLinearCovariateModel(n_covariates=1)1325      
  )1326        pop_models = [cov_pop_model] * 71327        with self.assertWarns(UserWarning):1328            self.pd_problem.set_population_model(pop_models)1329class TestInverseProblem(unittest.TestCase):1330    """1331    Tests the chi.InverseProblem class.1332    """1333    @classmethod1334    def setUpClass(cls):1335        # Create test data1336        cls.times = [1, 2, 3, 4, 5]1337        cls.values = [1, 2, 3, 4, 5]1338        # Set up inverse problem1339        path = ModelLibrary().tumour_growth_inhibition_model_koch()1340        cls.model = chi.PharmacodynamicModel(path)1341        cls.problem = chi.InverseProblem(cls.model, cls.times, cls.values)1342    def test_bad_model_input(self):1343        model = 'bad model'1344        with self.assertRaisesRegex(ValueError, 'Model has to be an instance'):1345            chi.InverseProblem(model, self.times, self.values)1346    def test_bad_times_input(self):1347        times = [-1, 2, 3, 4, 5]1348        with self.assertRaisesRegex(ValueError, 'Times cannot be negative.'):1349            chi.InverseProblem(self.model, times, self.values)1350        times = [5, 4, 3, 2, 1]1351        with self.assertRaisesRegex(ValueError, 'Times must be increasing.'):1352            chi.InverseProblem(self.model, times, self.values)1353    def test_bad_values_input(self):1354        values = [1, 2, 3, 4, 5, 6, 7]1355        with self.assertRaisesRegex(ValueError, 'Values array must have'):1356            chi.InverseProblem(self.model, self.times, values)1357        values = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]1358        with self.assertRaisesRegex(ValueError, 'Values array must have'):1359            chi.InverseProblem(self.model, self.times, values)1360    def test_evaluate(self):1361        parameters = [0.1, 1, 1, 1, 1]1362        output = self.problem.evaluate(parameters)1363        n_times = 51364        n_outputs = 11365        self.assertEqual(output.shape, (n_times, n_outputs))1366    def test_evaluateS1(self):1367       
 parameters = [0.1, 1, 1, 1, 1]1368        with self.assertRaises(NotImplementedError):1369            self.problem.evaluateS1(parameters)1370    def test_n_ouputs(self):1371        self.assertEqual(self.problem.n_outputs(), 1)1372    def test_n_parameters(self):1373        self.assertEqual(self.problem.n_parameters(), 5)1374    def test_n_times(self):1375        n_times = len(self.times)1376        self.assertEqual(self.problem.n_times(), n_times)1377    def test_times(self):1378        times = self.problem.times()1379        n_times = len(times)1380        self.assertEqual(n_times, 5)1381        self.assertEqual(times[0], self.times[0])1382        self.assertEqual(times[1], self.times[1])1383        self.assertEqual(times[2], self.times[2])1384        self.assertEqual(times[3], self.times[3])1385        self.assertEqual(times[4], self.times[4])1386    def test_values(self):1387        values = self.problem.values()1388        n_times = 51389        n_outputs = 11390        self.assertEqual(values.shape, (n_times, n_outputs))1391        self.assertEqual(values[0], self.values[0])1392        self.assertEqual(values[1], self.values[1])1393        self.assertEqual(values[2], self.values[2])1394        self.assertEqual(values[3], self.values[3])1395        self.assertEqual(values[4], self.values[4])1396if __name__ == '__main__':...track_model_train.py
Source:track_model_train.py  
from __future__ import absolute_import, division, print_function

import numpy as np

import caffe
from caffe import layers as L
from caffe import params as P

channel_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32)

###############################################################################
# Helper Methods
###############################################################################


def _conv_param_specs(param_names, bias_term, lr_mults, decay_mults):
    """Return the ``param`` list for a Convolution layer.

    ``param_names`` names the (weight, bias) blobs; layers that use the same
    names share those blobs. ``lr_mults`` and ``decay_mults`` are the matching
    (weight, bias) multipliers. Caffe matches ``param`` entries to the layer's
    blobs in order, so the bias entry is only emitted when the layer actually
    has a bias blob (``bias_term`` true).
    """
    specs = [dict(name=param_names[0], lr_mult=lr_mults[0],
                  decay_mult=decay_mults[0])]
    if bias_term:
        specs.append(dict(name=param_names[1], lr_mult=lr_mults[1],
                          decay_mult=decay_mults[1]))
    return specs


def conv(bottom, nout, ks=3, stride=1, pad=1, param_names=('conv_w', 'conv_b'),
         bias_term=True, fix_param=False, finetune=False):
    """Add a Convolution layer with ``nout`` outputs on top of ``bottom``.

    Learning-rate regime (``fix_param`` takes precedence over ``finetune``):

    * ``fix_param=True``: weights and bias frozen (lr_mult / decay_mult 0).
    * ``finetune=True``: reduced rates, lr_mult 0.1 (weight) / 0.2 (bias).
    * otherwise: lr_mult 1 / 2 with Xavier weight initialisation.

    ``bias_term`` is applied in every regime; when it is False no bias param
    spec is emitted.

    Returns the Convolution top.
    """
    if fix_param:
        mult = _conv_param_specs(param_names, bias_term, (0, 0), (0, 0))
        return L.Convolution(bottom, kernel_size=ks, stride=stride,
                             num_output=nout, pad=pad, bias_term=bias_term,
                             param=mult)
    if finetune:
        mult = _conv_param_specs(param_names, bias_term, (0.1, 0.2), (1, 0))
        return L.Convolution(bottom, kernel_size=ks, stride=stride,
                             num_output=nout, pad=pad, bias_term=bias_term,
                             param=mult)
    mult = _conv_param_specs(param_names, bias_term, (1, 2), (1, 0))
    return L.Convolution(bottom, kernel_size=ks, stride=stride,
                         num_output=nout, pad=pad, bias_term=bias_term,
                         param=mult, weight_filler=dict(type='xavier'))


def conv_relu(bottom, nout, ks=3, stride=1, pad=1, param_names=('conv_w', 'conv_b'),
              bias_term=True, fix_param=False, finetune=False):
    """Same as :func:`conv`, followed by an in-place ReLU.

    Returns ``(conv_top, relu_top)``.
    """
    conv_top = conv(bottom, nout, ks=ks, stride=stride, pad=pad,
                    param_names=param_names, bias_term=bias_term,
                    fix_param=fix_param, finetune=finetune)
    return conv_top, L.ReLU(conv_top, in_place=True)


def max_pool(bottom, ks=2, stride=2):
    """Add a MAX Pooling layer (default 2x2, stride 2) on ``bottom``."""
    return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)

################################################################################
# Model Generation
###############################################################################


def _register(n, name, top):
    """Store ``top`` on NetSpec ``n`` under ``name`` and return it."""
    setattr(n, name, top)
    return top


def _l2_norm(bottom, local_size):
    """Cross-channel LRN that the original code labels a "spatial L2 norm".

    ``alpha`` equals ``local_size``, ``beta`` is 0.5 and ``k`` is 1e-16. The
    window sizes used by the callers (513 and 1025) are 2*C + 1 for the 256-
    and 512-channel inputs, i.e. the window spans all channels.
    """
    return L.LRN(bottom, local_size=local_size, alpha=local_size, beta=0.5,
                 k=1e-16)


def _conv_stage(n, bottom, stage, n_convs, nout, suffix, fix_param, finetune):
    """Add ``conv{stage}_1`` .. ``conv{stage}_{n_convs}``, each with a ReLU.

    Tops are registered as ``conv{stage}_{i}{suffix}`` / ``relu{stage}_{i}{suffix}``.
    Param names ``conv{stage}_{i}_w`` / ``_b`` carry no suffix, so both
    branches share weights. Returns the last ReLU top.
    """
    top = bottom
    for idx in range(1, n_convs + 1):
        tag = '%d_%d' % (stage, idx)
        conv_top, relu_top = conv_relu(
            top, nout,
            param_names=('conv%s_w' % tag, 'conv%s_b' % tag),
            fix_param=fix_param,
            finetune=finetune)
        _register(n, 'conv' + tag + suffix, conv_top)
        top = _register(n, 'relu' + tag + suffix, relu_top)
    return top


def _vgg_features(n, image, suffix, config):
    """Add one VGG-16 branch (conv1_1 .. conv4_3) on ``image`` to ``n``.

    Every top is registered under the same name, and in the same order, as
    the original hand-unrolled code (``suffix`` is ``''`` for branch 1 and
    ``'_p'`` for branch 2). conv1/conv2 are always frozen; conv3/conv4 follow
    ``config.fix_vgg`` / ``config.finetune``.

    Returns the channel-wise concat of the normalised pool3 and relu4_3 maps.
    The original also carried a commented-out conv5 branch (upsampled relu5_3
    as a third concat input); it was disabled and is omitted here.
    """
    top = _conv_stage(n, image, 1, 2, 64, suffix, True, False)
    top = _register(n, 'pool1' + suffix, max_pool(top))
    top = _conv_stage(n, top, 2, 2, 128, suffix, True, False)
    top = _register(n, 'pool2' + suffix, max_pool(top))
    top = _conv_stage(n, top, 3, 3, 256, suffix,
                      config.fix_vgg, config.finetune)
    pool3 = _register(n, 'pool3' + suffix, max_pool(top))
    pool3_lrn = _register(n, 'pool3_lrn' + suffix, _l2_norm(pool3, 513))
    # conv4 consumes the un-normalised pool3 output
    top = _conv_stage(n, pool3, 4, 3, 512, suffix,
                      config.fix_vgg, config.finetune)
    relu4_3_lrn = _register(n, 'relu4_3_lrn' + suffix, _l2_norm(top, 1025))
    return L.Concat(pool3_lrn, relu4_3_lrn, concat_param=dict(axis=1))


def _score_net(split, config):
    """Build the two-branch scoring network shared by both generators.

    Branch 1 (``image1``) features are cropped with ``feat_crop``. They are
    passed as the second input of a (custom) DynamicConvolution layer applied
    to the branch 2 (``image2``) features, producing ``fcn_scores``.
    Returns the NetSpec.
    """
    n = caffe.NetSpec()
    mode_str = str(dict(dataset=config.dataset, split=split,
                        batch_size=config.N))
    n.image1, n.image2, n.label, n.sample_weights, n.feat_crop = L.Python(
        module=config.data_provider,
        layer=config.data_provider_layer,
        param_str=mode_str,
        ntop=5)

    # the base net (VGG-16) branch 1
    n.feat_all1 = _vgg_features(n, n.image1, '', config)
    n.feat_all1_crop = L.Crop(
        n.feat_all1, n.feat_crop,
        crop_param=dict(axis=2, offset=[config.query_featmap_H // 3,
                                        config.query_featmap_W // 3]))

    # the base net (VGG-16) branch 2
    n.feat_all2 = _vgg_features(n, n.image2, '_p', config)

    # Dyn conv layer
    n.fcn_scores = L.DynamicConvolution(
        n.feat_all2, n.feat_all1_crop,
        convolution_param=dict(num_output=1, kernel_size=11, stride=1, pad=5,
                               bias_term=False))
    return n


def generate_scores(split, config):
    """Return the NetParameter of the scoring network (no loss layers).

    ``split`` is forwarded to the Python data layer. ``config`` must provide
    ``dataset``, ``N``, ``data_provider``, ``data_provider_layer``,
    ``fix_vgg``, ``finetune``, ``query_featmap_H`` and ``query_featmap_W``.
    """
    n = _score_net(split, config)
    return n.to_proto()


def generate_model(split, config):
    """Build the training network: the scoring network plus a weighted loss.

    Scores are rescaled by a Power layer (scale 0.01196, shift -1.0,
    power 1). They are then fed to a WeightedSigmoidCrossEntropyLoss together
    with ``label`` and ``sample_weights`` from the data layer.
    """
    n = _score_net(split, config)

    # scale scores with zero mean 0.01196 -> 0.02677
    n.fcn_scaled_scores = L.Power(n.fcn_scores,
                                  power_param=dict(scale=0.01196,
                                                   shift=-1.0,
                                                   power=1))
    # Loss Layer
    n.loss = L.WeightedSigmoidCrossEntropyLoss(n.fcn_scaled_scores, n.label,
                                               n.sample_weights)
    # NOTE(review): this copy of the file is truncated after the loss layer;
    # restore the remainder (presumably ``return n.to_proto()``) from upstream.
Source:translations.py  
1"""2    Flowblade Movie Editor is a nonlinear video editor.3    Copyright 2012 Janne Liljeblad.4    This file is part of Flowblade Movie Editor <http://code.google.com/p/flowblade>.5    Flowblade Movie Editor is free software: you can redistribute it and/or modify6    it under the terms of the GNU General Public License as published by7    the Free Software Foundation, either version 3 of the License, or8    (at your option) any later version.9    Flowblade Movie Editor is distributed in the hope that it will be useful,10    but WITHOUT ANY WARRANTY; without even the implied warranty of11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the12    GNU General Public License for more details.13    You should have received a copy of the GNU General Public License14    along with Flowblade Movie Editor.  If not, see <http://www.gnu.org/licenses/>.15"""16import gettext17import locale18import os19import respaths20APP_NAME = "Flowblade"21lang = None22filter_groups = {}23filter_names = {}24param_names = {}25combo_options = {}26def init_languages():27    langs = []28    lc, encoding = locale.getdefaultlocale()29    if (lc):30        langs = [lc]31    print "Locale:", lc32    language = os.environ.get('LANGUAGE', None)33    if (language):34        langs += language.split(":")35    gettext.bindtextdomain(APP_NAME, respaths.LOCALE_PATH)36    gettext.textdomain(APP_NAME)37    # Get the language to use38    global lang39    #lang = gettext.translation(APP_NAME, respaths.LOCALE_PATH, languages=["fi"], fallback=True) # Testing, comment out for production40    lang = gettext.translation(APP_NAME, respaths.LOCALE_PATH, languages=langs, fallback=True)41    lang.install(APP_NAME) # makes _() a build-in available in all modules without imports42def get_filter_name(f_name):43    try:44        return filter_names[f_name]45    except KeyError:46        return f_name47def get_filter_group_name(group_name):48    try:49        return filter_groups[group_name]50    except:51       
 return group_name52def get_param_name(name):53    try:54        return param_names[name]55    except KeyError:56        return name57def get_combo_option(c_opt):58    try:59        return combo_options[c_opt]60    except KeyError:61        return c_opt62        63def load_filters_translations():64    # filter group names65    global filter_groups66    filter_groups["Color"] = _("Color")67    filter_groups["Color Effect"] = _("Color Effect")68    filter_groups["Audio"] = _("Audio")69    filter_groups["Audio Filter"] = _("Audio Filter")70    filter_groups["Blur"] = _("Blur")71    filter_groups["Distort"] = _("Distort")72    filter_groups["Alpha"] = _("Alpha")73    filter_groups["Movement"] = _("Movement")74    filter_groups["Transform"] = _("Transform")75    filter_groups["Edge"] = _("Edge")76    filter_groups["Fix"] = _("Fix")77    filter_groups["Artistic"] = _("Artistic")78    # filter names79    global filter_names80    filter_names["Alpha Gradient"] = _("Alpha Gradient")81    filter_names["Crop"] = _("Crop")82    filter_names["Alpha Shape"]= _("Alpha Shape")83    84    filter_names["Volume"]= _("Volume")85    filter_names["Pan"]= _("Pan")86    filter_names["Pan Keyframed"]= _("Pan Keyframed")87    filter_names["Mono to Stereo"]= _("Mono to Stereo")88    filter_names["Swap Channels"]= _("Swap Channels")89    filter_names["Pitchshifter"]= _("Pitchshifter")90    filter_names["Distort - Barry's Satan"]= _("Distort - Barry's Satan")91    filter_names["Frequency Shift - Bode/Moog"]= _("Frequency Shift - Bode/Moog")92    filter_names["Equalize - DJ 3-band"]= _("Equalize - DJ 3-band")93    filter_names["Flanger - DJ"]= _("Flanger - DJ")94    filter_names["Declipper"]= _("Declipper")95    filter_names["Delayorama"]= _("Delayorama")96    filter_names["Distort - Diode Processor"]= _("Distort - Diode Processor")97    filter_names["Distort - Foldover"]= _("Distort - Foldover")98    filter_names["Highpass - Butterworth"]= _("Highpass - Butterworth")99    filter_names["Lowpass 
- Butterworth"]= _("Lowpass - Butterworth")100    filter_names["GSM Simulator"]= _("GSM Simulator")101    filter_names["Reverb - GVerb"]= _("Reverb - GVerb")102    filter_names["Noise Gate"]= _("Noise Gate")103    filter_names["Bandpass"]= _("Bandpass")104    filter_names["Pitchscaler - High Quality"]= _("Pitchscaler - High Quality")105    filter_names["Equalize - Multiband"]= _("Equalize - Multiband")106    filter_names["Reverb - Plate"]= _("Reverb - Plate")107    filter_names["Distort - Pointer cast"]= _("Distort - Pointer cast")108    filter_names["Rate Shifter"]= _("Rate Shifter")109    filter_names["Signal Shifter"]= _("Signal Shifter")110    filter_names["Distort - Sinus Wavewrap"]= _("Distort - Sinus Wavewrap")111    filter_names["Vinyl Effect"]= _("Vinyl Effect")112    filter_names["Chorus - Multivoice"]= _("Chorus - Multivoice")113    filter_names["Charcoal"]= _("Charcoal")114    filter_names["Glow"]= _("Glow")115    filter_names["Old Film"]= _("Old Film")116    filter_names["Scanlines"]= _("Scanlines")117    filter_names["Cartoon"]= _("Cartoon")118    119    filter_names["Pixelize"]= _("Pixelize")120    filter_names["Blur"]= _("Blur")121    filter_names["Grain"]= _("Grain")122    123    filter_names["Grayscale"]= _("Grayscale")124    filter_names["Contrast"]= _("Contrast")125    filter_names["Saturation"]= _("Saturation")126    filter_names["Invert"]= _("Invert")127    filter_names["Hue"]= _("Hue")128    filter_names["Brightness"]= _("Brightness")129    filter_names["Sepia"]= _("Sepia")130    filter_names["Tint"]= _("Tint")131    filter_names["White Balance"]= _("White Balance")132    filter_names["Levels"]= _("Levels")133    filter_names["Color Clustering"]= _("Color Clustering")134    filter_names["Chroma Hold"]= _("Chroma Hold")135    filter_names["Three Layer"]= _("Three Layer")136    filter_names["Threshold0r"]= _("Threshold0r")137    filter_names["Technicolor"]= _("Technicolor")138    filter_names["Primaries"]= _("Primaries")139    
filter_names["Color Distance"]= _("Color Distance")140    filter_names["Threshold"]= _("Threshold")141    filter_names["Waves"]= _("Waves")142    filter_names["Lens Correction"]= _("Lens Correction")143    filter_names["Flip"]= _("Flip")144    filter_names["Mirror"]= _("Mirror")145    filter_names["V Sync"]= _("V Sync")146    filter_names["Edge Glow"]= _("Edge Glow")147    filter_names["Sobel"]= _("Sobel")148    filter_names["Denoise"]= _("Denoise")149    filter_names["Sharpness"]= _("Sharpness")150    filter_names["Letterbox"]= _("Letterbox")151    filter_names["Baltan"]= _("Baltan")152    filter_names["Vertigo"]= _("Vertigo")153    filter_names["Nervous"]= _("Nervous")154    filter_names["Freeze"]= _("Freeze")155    filter_names["Rotate"]= _("Rotate")156    filter_names["Shear"]= _("Shear")157    filter_names["Translate"]= _("Translate")158    # 0.8 added159    filter_names["Color Select"]= _("Color Select")160    filter_names["Alpha Modify"]= _("Alpha Modify")161    filter_names["Spill Supress"]= _("Spill Supress")162    filter_names["RGB Noise"]= _("RGB Noise")163    filter_names["Box Blur"]= _("Box Blur")164    filter_names["IRR Blur"]= _("IRR Blur")165    filter_names["Color Halftone"]= _("Color Halftone")166    filter_names["Dither"]= _("Dither")167    filter_names["Vignette"]= _("Vignette")168    filter_names["Emboss"]= _("Emboss")169    filter_names["3 Point Balance"]= _("3 Point Balance")170    filter_names["Colorize"]= _("Colorize")171    filter_names["Brightness Keyframed"]= _("Brightness Keyframed")172    filter_names["RGB Adjustment"]= _("RGB Adjustment")173    filter_names["Color Tap"]= _("Color Tap")174    filter_names["Posterize"]= _("Posterize")175    filter_names["Soft Glow"]= _("Soft Glow")176    filter_names["Newspaper"]= _("Newspaper")177    # 0.16 added178    filter_names["Luma Key"] = _("Luma Key")179    filter_names["Chroma Key"] = _("Chroma Key")180    filter_names["Affine"] = _("Affine")181    filter_names["Color Adjustment"] = _("Color 
Adjustment")182    filter_names["Color Grading"] = _("Color Grading")183    filter_names["Curves"] = _("Curves")184    filter_names["Lift Gain Gamma"] = _("Lift Gain Gamma")185    filter_names["Image Grid"] = _("Image Grid")186    187    # 0.18188    filter_names["Color Lift Gain Gamma"] = _("Color Lift Gain Gamma")189    190    # param names191    global param_names192    # param names for filters193    param_names["Position"] = _("Position")194    param_names["Grad width"] = _("Grad width")195    param_names["Tilt"] = _("Tilt")196    param_names["Min"] = _("Min")197    param_names["Max"] = _("Max")198    param_names["Left"] = _("Left")199    param_names["Right"] = _("Right")200    param_names["Top"] = _("Top")201    param_names["Bottom"] = _("Bottom")202    param_names["Shape"] = _("Shape")203    param_names["Pos X"] = _("Pos X")204    param_names["Pos Y"] = _("Pos Y")205    param_names["Size X"] = _("Size X")206    param_names["Size Y"] = _("Size Y")207    param_names["Tilt"] = _("Tilt")208    param_names["Trans. Width"] = _("Trans. 
Width")209    param_names["Volume"] = _("Volume")210    param_names["Left/Right"] = _("Left/Right")211    param_names["Left/Right"] = _("Left/Right")212    param_names["Dry/Wet"] = _("Dry/Wet")213    param_names["Pitch Shift"] = _("Pitch Shift")214    param_names["Buffer Size"] = _("Buffer Size")215    param_names["Dry/Wet"] = _("Dry/Wet")216    param_names["Decay Time(samples)"] = _("Decay Time(samples)")217    param_names["Knee Point(dB)"] = _("Knee Point(dB)")218    param_names["Dry/Wet"] = _("Dry/Wet")219    param_names["Frequency shift"] = _("Frequency shift")220    param_names["Dry/Wet"] = _("Dry/Wet")221    param_names["Low Gain(dB)"] = _("Low Gain(dB)")222    param_names["Mid Gain(dB)"] = _("Mid Gain(dB)")223    param_names["High Gain(dB)"] = _("High Gain(dB)")224    param_names["Dry/Wet"] = _("Dry/Wet")225    param_names["Oscillation period(s)"] = _("Oscillation period(s)")226    param_names["Oscillation depth(ms)"] = _("Oscillation depth(ms)")227    param_names["Feedback%"] = _("Feedback%")228    param_names["Dry/Wet"] = _("Dry/Wet")229    param_names["Dry/Wet"] = _("Dry/Wet")230    param_names["Random seed"] = _("Random seed")231    param_names["Input Gain(dB)"] = _("Input Gain(dB)")232    param_names["Feedback(%)"] = _("Feedback(%)")233    param_names["Number of taps"] = _("Number of taps")234    param_names["First Delay(s)"] = _("First Delay(s)")235    param_names["Delay Range(s)"] = _("Delay Range(s)")236    param_names["Delay Change"] = _("Delay Change")237    param_names["Delay Random(%)"] = _("Delay Random(%)")238    param_names["Amplitude Change"] = _("Amplitude Change")239    param_names["Amplitude Random(%)"] = _("Amplitude Random(%)")240    param_names["Dry/Wet"] = _("Dry/Wet")241    param_names["Amount"] = _("Amount")242    param_names["Dry/Wet"] = _("Dry/Wet")243    param_names["Drive"] = _("Drive")244    param_names["Skew"] = _("Skew")245    param_names["Dry/Wet"] = _("Dry/Wet")246    param_names["Cutoff Frequency(Hz)"] = _("Cutoff 
Frequency(Hz)")247    param_names["Resonance"] = _("Resonance")248    param_names["Dry/Wet"] = _("Dry/Wet")249    param_names["Cutoff Frequency(Hz)"] = _("Cutoff Frequency(Hz)")250    param_names["Resonance"] = _("Resonance")251    param_names["Dry/Wet"] = _("Dry/Wet")252    param_names["Passes"] = _("Passes")253    param_names["Error Rate"] = _("Error Rate")254    param_names["Dry/Wet"] = _("Dry/Wet")255    param_names["Roomsize"] = _("Roomsize")256    param_names["Reverb time(s)"] = _("Reverb time(s)")257    param_names["Damping"] = _("Damping")258    param_names["Input bandwith"] = _("Input bandwith")259    param_names["Dry signal level(dB)"] = _("Dry signal level(dB)")260    param_names["Early reflection level(dB)"] = _("Early reflection level(dB)")261    param_names["Tail level(dB)"] = _("Tail level(dB)")262    param_names["Dry/Wet"] = _("Dry/Wet")263    param_names["LF keyfilter(Hz)"] = _("LF keyfilter(Hz)")264    param_names["HF keyfilter(Hz)"] = _("HF keyfilter(Hz)")265    param_names["Threshold(dB)"] = _("Threshold(dB)")266    param_names["Attack(ms)"] = _("Attack(ms)")267    param_names["Hold(ms)"] = _("Hold(ms)")268    param_names["Decay(ms)"] = _("Decay(ms)")269    param_names["Range(dB)"] = _("Range(dB)")270    param_names["Dry/Wet"] = _("Dry/Wet")271    param_names["Center Frequency(Hz)"] = _("Center Frequency(Hz)")272    param_names["Bandwidth(Hz)"] = _("Bandwidth(Hz)")273    param_names["Stages"] = _("Stages")274    param_names["Dry/Wet"] = _("Dry/Wet")275    param_names["Pitch-coefficient"] = _("Pitch-coefficient")276    param_names["Dry/Wet"] = _("Dry/Wet")277    param_names["50Hz gain"] = _("50Hz gain")278    param_names["100Hz gain"] = _("100Hz gain")279    param_names["156Hz gain"] = _("156Hz gain")280    param_names["220Hz gain"] = _("220Hz gain")281    param_names["311Hz gain"] = _("311Hz gain")282    param_names["440Hz gain"] = _("440Hz gain")283    param_names["622Hz gain"] = _("622Hz gain")284    param_names["880Hz gain"] = _("880Hz 
gain")285    param_names["1250Hz gain"] = _("1250Hz gain")286    param_names["1750Hz gain"] = _("1750Hz gain")287    param_names["2500Hz gain"] = _("2500Hz gain")288    param_names["3500Hz gain"] = _("3500Hz gain")289    param_names["5000Hz gain"] = _("5000Hz gain")290    param_names["100000Hz gain"] = _("100000Hz gain")291    param_names["200000Hz gain"] = _("200000Hz gain")292    param_names["Dry/Wet"] = _("Dry/Wet")293    param_names["Reverb time"] = _("Reverb time")294    param_names["Damping"] = _("Damping")295    param_names["Dry/Wet mix"] = _("Dry/Wet mix")296    param_names["Dry/Wet"] = _("Dry/Wet")297    param_names["Effect cutoff(Hz)"] = _("Effect cutoff(Hz)")298    param_names["Dry/Wet mix"] = _("Dry/Wet mix")299    param_names["Dry/Wet"] = _("Dry/Wet")300    param_names["Rate"] = _("Rate")301    param_names["Dry/Wet"] = _("Dry/Wet")302    param_names["Sift"] = _("Sift")303    param_names["Dry/Wet"] = _("Dry/Wet")304    param_names["Amount"] = _("Amount")305    param_names["Dry/Wet"] = _("Dry/Wet")306    param_names["Year"] = _("Year")307    param_names["RPM"] = _("RPM")308    param_names["Surface warping"] = _("Surface warping")309    param_names["Cracle"] = _("Cracle")310    param_names["Wear"] = _("Wear")311    param_names["Dry/Wet"] = _("Dry/Wet")312    param_names["Number of voices"] = _("Number of voices")313    param_names["Delay base(ms)"] = _("Delay base(ms)")314    param_names["Voice separation(ms)"] = _("Voice separation(ms)")315    param_names["Detune(%)"] = _("Detune(%)")316    param_names["Oscillation frequency(Hz)"] = _("Oscillation frequency(Hz)")317    param_names["Output attenuation(dB)"] = _("Output attenuation(dB)")318    param_names["Dry/Wet"] = _("Dry/Wet")319    param_names["X Scatter"] = _("X Scatter")320    param_names["Y Scatter"] = _("Y Scatter")321    param_names["Scale"] = _("Scale")322    param_names["Mix"] = _("Mix")323    param_names["Invert"] = _("Invert")324    param_names["Blur"] = _("Blur")325    param_names["Delta"] = 
_("Delta")326    param_names["Duration"] = _("Duration")327    param_names["Bright. up"] = _("Bright. up")328    param_names["Bright. down"] = _("Bright. down")329    param_names["Bright. dur."] = _("Bright. dur.")330    param_names["Develop up"] = _("Develop up")331    param_names["Develop down"] = _("Develop down")332    param_names["Develop dur."] = _("Develop dur.")333    param_names["Triplevel"] = _("Triplevel")334    param_names["Difference Space"] = _("Difference Space")335    param_names["Block width"] = _("Block width")336    param_names["Block height"] = _("Block height")337    param_names["Size"] = _("Size")338    param_names["Noise"] = _("Noise")339    param_names["Contrast"] = _("Contrast")340    param_names["Brightness"] = _("Brightness")341    param_names["Contrast"] = _("Contrast")342    param_names["Saturation"] = _("Saturation")343    param_names["Hue"] = _("Hue")344    param_names["Brightness"] = _("Brightness")345    param_names["Brightness"] = _("Brightness")346    param_names["U"] = _("U")347    param_names["V"] = _("V")348    param_names["Black"] = _("Black")349    param_names["White"] = _("White")350    param_names["Amount"] = _("Amount")351    param_names["Neutral Color"] = _("Neutral Color")352    param_names["Input"] = _("Input")353    param_names["Input"] = _("Input")354    param_names["Gamma"] = _("Gamma")355    param_names["Black"] = _("Black")356    param_names["White"] = _("White")357    param_names["Num"] = _("Num")358    param_names["Dist. weight"] = _("Dist. 
weight")359    param_names["Color"] = _("Color")360    param_names["Variance"] = _("Variance")361    param_names["Threshold"] = _("Threshold")362    param_names["Red Saturation"] = _("Red Saturation")363    param_names["Yellow Saturation"] = _("Yellow Saturation")364    param_names["Factor"] = _("Factor")365    param_names["Source color"] = _("Source color")366    param_names["Threshold"] = _("Threshold")367    param_names["Amplitude"] = _("Amplitude")368    param_names["Frequency"] = _("Frequency")369    param_names["Rotate"] = _("Rotate")370    param_names["Tilt"] = _("Tilt")371    param_names["Center Correct"] = _("Center Correct")372    param_names["Edges Correct"] = _("Edges Correct")373    param_names["Flip"] = _("Flip")374    param_names["Axis"] = _("Axis")375    param_names["Invert"] = _("Invert")376    param_names["Position"] = _("Position")377    param_names["Edge Lightning"] = _("Edge Lightning")378    param_names["Edge Brightness"] = _("Edge Brightness")379    param_names["Non-Edge Brightness"] = _("Non-Edge Brightness")380    param_names["Spatial"] = _("Spatial")381    param_names["Temporal"] = _("Temporal")382    param_names["Amount"] = _("Amount")383    param_names["Size"] = _("Size")384    param_names["Border width"] = _("Border width")385    param_names["Phase Incr."] = _("Phase Incr.")386    param_names["Zoom"] = _("Zoom")387    param_names["Freeze Frame"] = _("Freeze Frame")388    param_names["Freeze After"] = _("Freeze After")389    param_names["Freeze Before"] = _("Freeze Before")390    param_names["Angle"] = _("Angle")391    param_names["transition.geometry"] = _("transition.geometry")392    param_names["Shear X"] = _("Shear X")393    param_names["Shear Y"] = _("Shear Y")394    param_names["transition.geometry"] = _("transition.geometry")395    param_names["transition.geometry"] = _("transition.geometry")396    param_names["Left"] = _("Left")397    param_names["Right"] = _("Right")398    param_names["Top"] = _("Top")399    
param_names["Bottom"] = _("Bottom")400    param_names["Invert"] = _("Invert")401    param_names["Blur"] = _("Blur")402    param_names["Opacity"] = _("Opacity")403    param_names["Opacity"] = _("Opacity")404    param_names["Rotate X"] = _("Rotate X")405    param_names["Rotate Y"] = _("Rotate Y")406    param_names["Rotate Z"] = _("Rotate Z")407    # added 0.8408    param_names["Edge Mode"] = _("Edge Mode")409    param_names["Sel. Space"] = _("Sel. Space")410    param_names["Operation"] = _("Operation")411    param_names["Hard"] = _("Hard")412    param_names["R/A/Hue"] = _("R/A/Hue")413    param_names["G/B/Chromae"] = _("G/B/Chroma")414    param_names["B/I/I"] = _("B/I/I")415    param_names["Supress"] = _("Supress")416    param_names["Horizontal"] = _("Horizontal")417    param_names["Vertical"] = _("Vertical")418    param_names["Type"] = _("Type")419    param_names["Edge"] = _("Edge")420    param_names["Dot Radius"] = _("Dot Radius")421    param_names["Cyan Angle"] = _("Cyan Angle")422    param_names["Magenta Angle"] = _("Magenta Angle")423    param_names["Yellow Angle"] = _("Yellow Angle")424    param_names["Levels"] = _("Levels")425    param_names["Matrix Type"] = _("Matrix Type")426    param_names["Aspect"] = _("Aspect")427    param_names["Center Size"] = _("Center Size")428    param_names["Azimuth"] = _("Azimuth")429    param_names["Lightness"] = _("Lightness")430    param_names["Bump Height"] = _("Bump Height")431    param_names["Gray"] = _("Gray")432    param_names["Split Preview"] = _("Split Preview")433    param_names["Source on Left"] = _("Source on Left")434    param_names["Lightness"] = _("Lightness")435    param_names["Input black level"] = _("Input black level")436    param_names["Input white level"] = _("Input white level")437    param_names["Black output"] = _("Black output")438    param_names["White output"] = _("White output")439    param_names["Red"] = _("Red")440    param_names["Green"] = _("Green")441    param_names["Blue"] = _("Blue")442    
param_names["Action"] = _("Action")443    param_names["Keep Luma"] = _("Keep Luma")444    param_names["Luma Formula"] = _("Luma Formula")445    param_names["Effect"] = _("Effect")446    param_names["Sharpness"] = _("Sharpness")447    param_names["Blend Type"] = _("Blend Type")448    # added 0.16449    param_names["Key Color"] = _("Key Color")450    param_names["Pre-Level"] = _("Pre-Level")451    param_names["Post-Level"] = _("Post-Level")452    param_names["Slope"] = _("Slope")453    param_names["Luma Band"] = _("Luma Band")454    param_names["Lift"] = _("Lift")455    param_names["Gain"] = _("Gain")456    param_names["Input White Level"] = _("Input White Level")457    param_names["Input Black Level"] = _("Input Black Level")458    param_names["Black Output"] = _("Black Output")459    param_names["White Output"] = _("White Output")460    param_names["Rows"] = _("Rows")461    param_names["Columns"] = _("Columns")462    param_names["Color Temperature"] = _("Color Temperature")463    # param names for compositors464    param_names["Opacity"] = _("Opacity")465    param_names["Shear X"] = _("Shear X")466    param_names["Shear Y"] = _("Shear Y")467    param_names["Distort"] = _("Distort")468    param_names["Opacity"] = _("Opacity")469    param_names["Wipe Type"] = _("Wipe Type")470    param_names["Invert"] = _("Invert")471    param_names["Softness"] = _("Softness")472    param_names["Wipe Amount"] = _("Wipe Amount")473    param_names["Wipe Type"] = _("Wipe Type")474    param_names["Invert"] = _("Invert")475    param_names["Softness"] = _("Softness")476    # Combo options477    global combo_options478    combo_options["Shave"] = _("Shave")479    combo_options["Rectangle"] = _("Rectangle")480    combo_options["Ellipse"] = _("Ellipse")481    combo_options["Triangle"] = _("Triangle")482    combo_options["Diamond"] = _("Diamond")483    combo_options["Shave"] = _("Shave")484    combo_options["Shrink Hard"] = _("Shrink Hard")485    combo_options["Shrink Soft"] = _("Shrink 
Soft")486    combo_options["Grow Hard"] = _("Grow Hard")487    combo_options["Grow Soft"] = _("Grow Soft")488    combo_options["RGB"] = _("RGB")489    combo_options["ABI"] = _("ABI")490    combo_options["HCI"] = _("HCI")491    combo_options["Hard"] = _("Hard")492    combo_options["Fat"] = _("Fat")493    combo_options["Normal"] = _("Normal")494    combo_options["Skinny"] = _("Skinny")495    combo_options["Ellipsoid"] = _("Ellipsoid")496    combo_options["Diamond"] = _("Diamond")497    combo_options["Overwrite"] = _("Overwrite")498    combo_options["Max"] = _("Max")499    combo_options["Min"] = _("Min")500    combo_options["Add"] = _("Add")501    combo_options["Subtract"] = _("Subtract")502    combo_options["Green"] = _("Green")503    combo_options["Blue"] = _("Blue")504    combo_options["Sharper"] = _("Sharper")505    combo_options["Fuzzier"] = _("Fuzzier")506    combo_options["Luma"] = _("Luma")507    combo_options["Red"] = _("Red")508    combo_options["Green"] = _("Green")509    combo_options["Blue"] = _("Blue")510    combo_options["Add Constant"] = _("Add Constant")511    combo_options["Change Gamma"] = _("Change Gamma")512    combo_options["Multiply"] = _("Multiply")513    combo_options["XPro"] = _("XPro")514    combo_options["OldPhoto"] = _("OldPhoto")515    combo_options["Sepia"] = _("Sepia")516    combo_options["Heat"] = _("Heat")517    combo_options["XRay"] = _("XRay")518    combo_options["RedGreen"] = _("RedGreen")519    combo_options["YellowBlue"] = _("YellowBlue")520    combo_options["Esses"] = _("Esses")521    combo_options["Horizontal"] = _("Horizontal")522    combo_options["Vertical"] = _("Vertical")523    combo_options["Shadows"] = _("Shadows")524    combo_options["Midtones"] = _("Midtones")525    combo_options["Highlights"] = _("Highlights")...interpolate_core_collapse_timescale.py
Source:interpolate_core_collapse_timescale.py  
1import numpy as np2from sidmpy.core_collapse_timescale import fraction_collapsed_halos, fraction_collapsed_halos_pool3from scipy.interpolate import RegularGridInterpolator4from scipy.interpolate import interp1d5import pickle6from multiprocess.pool import Pool7class InterpolatedCollapseTimescale(object):8    def __init__(self, points, values, param_names, param_arrays):9        self.param_names = param_names10        self.param_ranges = []11        self.param_ranges_dict = {}12        for i, param in enumerate(param_arrays):13            ran = [param[0], param[-1]]14            self.param_ranges.append(ran)15            self.param_ranges_dict[param_names[i]] = ran16        self._interp_function = RegularGridInterpolator(points, values,17                                                        bounds_error=False, fill_value=None)18    @classmethod19    def fromParamArray(self, m1, m2, cross_section_model, param_names, param_arrays, params_fixed={},20                 kwargs_fraction={}, nproc=8):21        param_names = param_names22        param_ranges = []23        param_ranges_dict = {}24        for i, param in enumerate(param_arrays):25            ran = [param[0], param[-1]]26            param_ranges.append(ran)27            param_ranges_dict[param_names[i]] = ran28        print('param_names: ', param_names)29        print('n params: ', len(param_names))30        print('n sample arrays: ', len(param_arrays))31        # redshift is always last32        if len(param_arrays) == 2:33            args_list = []34            points = (param_arrays[0], param_arrays[1])35            n_total = len(param_arrays[0]) * len(param_arrays[1])36            print('n total: ', n_total)37            for p1 in param_arrays[0]:38                for redshift in param_arrays[1]:39                    kw = {param_names[0]: p1}40                    kw.update(params_fixed)41                    kwargs_fraction['redshift'] = redshift42                    cross_model = 
cross_section_model(**kw)43                    new = (m1, m2, cross_model, kwargs_fraction['redshift'], kwargs_fraction['timescale_factor'])44                    args_list.append(new)45            shape = (len(param_arrays[0]), len(param_arrays[1]))46        elif len(param_arrays) == 3:47            args_list = []48            points = (param_arrays[0], param_arrays[1], param_arrays[2])49            n_total = len(param_arrays[0]) * len(param_arrays[1]) * len(param_arrays[2])50            print('n total: ', n_total)51            for p1 in param_arrays[0]:52                for p2 in param_arrays[1]:53                    for redshift in param_arrays[2]:54                        # if counter % step == 0:55                        #     print(str(np.round(100 * counter / n_total, 1)) + '% ')56                        kw = {param_names[0]: p1, param_names[1]: p2}57                        kw.update(params_fixed)58                        kwargs_fraction['redshift'] = redshift59                        cross_model = cross_section_model(**kw)60                        new = (m1, m2, cross_model, kwargs_fraction['redshift'], kwargs_fraction['timescale_factor'])61                        args_list.append(new)62            shape = (len(param_arrays[0]), len(param_arrays[1]), len(param_arrays[2]))63        elif len(param_arrays) == 4:64            points = (param_arrays[0], param_arrays[1], param_arrays[2], param_arrays[3])65            n_total = len(param_arrays[0]) * len(param_arrays[1]) * len(param_arrays[2]) * len(param_arrays[3])66            print('n total: ', n_total)67            args_list = []68            for p1 in param_arrays[0]:69                for p2 in param_arrays[1]:70                    for p3 in param_arrays[2]:71                        for redshift in param_arrays[3]:72                            kw = {param_names[0]: p1, param_names[1]: p2, param_names[2]: p3}73                            kw.update(params_fixed)74                            
kwargs_fraction['redshift'] = redshift75                            cross_model = cross_section_model(**kw)76                            new = (m1, m2, cross_model, kwargs_fraction['redshift'], kwargs_fraction['timescale_factor'])77                            args_list.append(new)78            pool = Pool(nproc)79            values = pool.map(fraction_collapsed_halos_pool, args_list)80            pool.close()81            shape = (len(param_arrays[0]), len(param_arrays[1]), len(param_arrays[2]), len(param_arrays[3]))82        elif len(param_arrays) == 5:83            points = (param_arrays[0], param_arrays[1], param_arrays[2], param_arrays[3], param_arrays[4])84            n_total = len(param_arrays[0]) * len(param_arrays[1]) * len(param_arrays[2]) * len(param_arrays[3]) * len(param_arrays[4])85            print('n total: ', n_total)86            args_list = []87            for p1 in param_arrays[0]:88                for p2 in param_arrays[1]:89                    for p3 in param_arrays[2]:90                        for p4 in param_arrays[3]:91                            for redshift in param_arrays[4]:92                                # if counter % step == 0:93                                #     print(str(np.round(100 * counter / n_total, 1)) + '% ')94                                kw = {param_names[0]: p1, param_names[1]: p2, param_names[2]: p3, param_names[3]: p4}95                                kw.update(params_fixed)96                                kwargs_fraction['redshift'] = redshift97                                cross_model = cross_section_model(**kw)98                                new = (99                                m1, m2, cross_model, kwargs_fraction['redshift'], kwargs_fraction['timescale_factor'])100                                args_list.append(new)101            shape = (len(param_arrays[0]), len(param_arrays[1]), len(param_arrays[2]), len(param_arrays[3]),102                     len(param_arrays[4]))103        elif 
len(param_arrays) == 6:104            points = (param_arrays[0], param_arrays[1], param_arrays[2], param_arrays[3], param_arrays[4], param_arrays[5])105            n_total = len(param_arrays[0]) * len(param_arrays[1]) * len(param_arrays[2]) * len(param_arrays[3]) * len(106                param_arrays[4]) * len(param_arrays[5])107            print('n total: ', n_total)108            args_list = []109            for p1 in param_arrays[0]:110                for p2 in param_arrays[1]:111                    for p3 in param_arrays[2]:112                        for p4 in param_arrays[3]:113                            for p5 in param_arrays[4]:114                                for redshift in param_arrays[5]:115                                    # if counter % step == 0:116                                    #     print(str(np.round(100 * counter / n_total, 1)) + '% ')117                                    kw = {param_names[0]: p1, param_names[1]: p2, param_names[2]: p3, param_names[3]: p4,118                                          param_names[4]: p5}119                                    kw.update(params_fixed)120                                    kwargs_fraction['redshift'] = redshift121                                    cross_model = cross_section_model(**kw)122                                    new = (123                                        m1, m2, cross_model, kwargs_fraction['redshift'],124                                        kwargs_fraction['timescale_factor'])125                                    args_list.append(new)126            pool = Pool(nproc)127            values = pool.map(fraction_collapsed_halos_pool, args_list)128            pool.close()129            shape = (len(param_arrays[0]), len(param_arrays[1]), len(param_arrays[2]), len(param_arrays[3]),130                     len(param_arrays[4]), len(param_arrays[5]))131        elif len(param_arrays) == 7:132            points = (param_arrays[0], param_arrays[1], param_arrays[2], 
param_arrays[3], param_arrays[4],133                      param_arrays[5], param_arrays[6])134            n_total = len(param_arrays[0]) * len(param_arrays[1]) * len(param_arrays[2]) * len(param_arrays[3]) * len(135                param_arrays[4]) * len(param_arrays[5] * len(param_arrays[6]))136            print('n total: ', n_total)137            args_list = []138            for p1 in param_arrays[0]:139                for p2 in param_arrays[1]:140                    for p3 in param_arrays[2]:141                        for p4 in param_arrays[3]:142                            for p5 in param_arrays[4]:143                                for timescale_factor in param_arrays[5]:144                                    for redshift in param_arrays[6]:145                                        kw = {param_names[0]: p1, param_names[1]: p2, param_names[2]: p3, param_names[3]: p4,146                                              param_names[4]: p5}147                                        kw.update(params_fixed)148                                        kwargs_fraction['redshift'] = redshift149                                        kwargs_fraction['timescale_factor'] = timescale_factor150                                        cross_model = cross_section_model(**kw)151                                        new = (152                                            m1, m2, cross_model, kwargs_fraction['redshift'],153                                            kwargs_fraction['timescale_factor'])154                                        args_list.append(new)155            pool = Pool(nproc)156            values = pool.map(fraction_collapsed_halos_pool, args_list)157            pool.close()158            shape = (len(param_arrays[0]), len(param_arrays[1]), len(param_arrays[2]), len(param_arrays[3]),159                     len(param_arrays[4]), len(param_arrays[5]), len(param_arrays[6]))160        else:161            raise Exception('only 2, 3, 4 and 5D interpolations 
implemented')162        return InterpolatedCollapseTimescale(points, values, param_names, param_arrays, shape)163    def __call__(self, *args):164        return np.squeeze(self._interp_function(tuple(args)))165def interpolate_collapse_fraction(fname, cross_section_class, param_names, param_arrays, params_fixed, m1,166                                  kwargs_collapse_fraction, nproc):167    interp_timescale = InterpolatedCollapseTimescale(m1, m1 * 1.05, cross_section_class,168                                                     param_names, param_arrays, params_fixed, kwargs_collapse_fraction, nproc=nproc)169    f = open('interpolated_collapse_fraction_'+fname, 'wb')170    pickle.dump(interp_timescale, f)171    f.close()172# from sidmpy.CrossSections.resonant_tchannel import ExpResonantTChannel173# # norm, v_ref, v_res, w_res, res_amplitude174# param_names = ['norm', 'v_ref', 'v_res', 'w_res', 'res_amplitude', 'timescale_factor', 'redshift']175# cross_model = ExpResonantTChannel176#177# output_folder = ''178# nproc = 50179# params_fixed = {}180# kwargs_collapse_fraction = {}181# z_array = [0.2, 0.45, 0.7, 0.95]182# tarray = [10/3, 15/3, 20/3]183# param_arrays = [np.linspace(1, 10.0, 9), np.linspace(1, 50.0, 20), np.linspace(1, 40, 20),184#                 np.linspace(1, 5.0, 5), np.linspace(1.0, 100, 40), tarray, z_array]185# n_total = 1186# for parr in param_arrays:187#     n_total *= len(parr)188# print('n_total: ', n_total); a=input('continue')189# fname = output_folder + 'logM68_expresonanttchannel'190# m1 = 10 ** 7191# interpolate_collapse_fraction(fname, cross_model, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction, nproc=nproc)192# fname = output_folder + 'logM89_expresonanttchannel'193# m1 = 10 ** 8.5194# interpolate_collapse_fraction(fname, cross_model, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction, nproc=nproc)195#196# fname = output_folder + 'logM910_expresonanttchannel'197# m1 = 10 ** 9.5198# 
interpolate_collapse_fraction(fname, cross_model, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction, nproc=nproc)199# from sidmpy.CrossSections.tchannel import TChannel200# param_names = ['norm', 'v_ref']201# n = 50202# params_fixed = {}203# kwargs_collapse_fraction = {'redshift': 0.5, 'timescale_factor': 20.0}204# param_arrays = [np.linspace(0.5, 60.0, n), np.linspace(1.0, 40, n)]205# fname = 'logM68_tchannel'206# m1 = 10 ** 7207# interpolate_collapse_fraction(fname, TChannel, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction)208#209# fname = 'logM89_tchannel'210# m1 = 5 * 10 ** 8211# interpolate_collapse_fraction(fname, TChannel, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction)212#213# fname = 'logM910_tchannel'214# m1 = 5 * 10 ** 9215# interpolate_collapse_fraction(fname, TChannel, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction)...Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from setting up the prerequisites to run your first automation test, to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile a list of step-by-step guides to help you be proficient with different test automation frameworks i.e. Selenium, Cypress, TestNG etc.
You can also refer to video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing for FREE!!
