@@ -967,6 +967,7 @@ def test_get_diff_vars_replace_custom_schema(self):
967967 mock_model .database = "a_dev_db"
968968 mock_model .schema_ = "a_custom_schema"
969969 mock_model .config .schema_ = mock_model .schema_
970+ mock_model .config .database = None
970971 mock_model .alias = "a_model_name"
971972 mock_tdatadiffmodelconfig = Mock ()
972973 mock_tdatadiffmodelconfig .where_filter = "where"
@@ -999,6 +1000,7 @@ def test_get_diff_vars_static_custom_schema(self):
9991000 primary_keys = ["a_primary_key" ]
10001001 mock_model .database = "a_dev_db"
10011002 mock_model .schema_ = "a_custom_schema"
1003+ mock_model .config .database = None
10021004 mock_model .config .schema_ = mock_model .schema_
10031005 mock_model .alias = "a_model_name"
10041006 mock_tdatadiffmodelconfig = Mock ()
@@ -1031,6 +1033,7 @@ def test_get_diff_vars_no_custom_schema_on_model(self):
10311033 mock_model .database = "a_dev_db"
10321034 mock_model .schema_ = "a_custom_schema"
10331035 mock_model .config .schema_ = None
1036+ mock_model .config .database = None
10341037 mock_model .alias = "a_model_name"
10351038 mock_tdatadiffmodelconfig = Mock ()
10361039 mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1060,6 +1063,7 @@ def test_get_diff_vars_match_dev_schema(self):
10601063 mock_model .database = "a_dev_db"
10611064 mock_model .schema_ = "a_schema"
10621065 mock_model .config .schema_ = None
1066+ mock_model .config .database = None
10631067 mock_model .alias = "a_model_name"
10641068 mock_tdatadiffmodelconfig = Mock ()
10651069 mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1107,6 +1111,7 @@ def test_get_diff_vars_meta_where(self):
11071111 mock_model .database = "a_dev_db"
11081112 mock_model .schema_ = "a_schema"
11091113 mock_model .config .schema_ = None
1114+ mock_model .config .database = None
11101115 mock_model .alias = "a_model_name"
11111116 mock_tdatadiffmodelconfig = Mock ()
11121117 mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1136,6 +1141,7 @@ def test_get_diff_vars_meta_unrelated(self):
11361141 mock_model .database = "a_dev_db"
11371142 mock_model .schema_ = "a_schema"
11381143 mock_model .config .schema_ = None
1144+ mock_model .config .database = None
11391145 mock_model .alias = "a_model_name"
11401146 mock_tdatadiffmodelconfig = Mock ()
11411147 mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1165,6 +1171,7 @@ def test_get_diff_vars_meta_none(self):
11651171 mock_model .database = "a_dev_db"
11661172 mock_model .schema_ = "a_schema"
11671173 mock_model .config .schema_ = None
1174+ mock_model .config .database = None
11681175 mock_model .alias = "a_model_name"
11691176 mock_tdatadiffmodelconfig = Mock ()
11701177 mock_tdatadiffmodelconfig .where_filter = "where"
@@ -1176,7 +1183,6 @@ def test_get_diff_vars_meta_none(self):
11761183 mock_dbt_parser .threads = 0
11771184 mock_dbt_parser .get_pk_from_model .return_value = primary_keys
11781185 mock_dbt_parser .requires_upper = False
1179- where = None
11801186 mock_model .meta = None
11811187
11821188 diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
@@ -1188,3 +1194,34 @@ def test_get_diff_vars_meta_none(self):
11881194 assert diff_vars .threads == mock_dbt_parser .threads
11891195 self .assertEqual (diff_vars .where_filter , mock_tdatadiffmodelconfig .where_filter )
11901196 mock_dbt_parser .get_pk_from_model .assert_called_once ()
1197+
1198+ def test_get_diff_vars_custom_db (self ):
1199+ mock_model = Mock ()
1200+ prod_database = "a_prod_db"
1201+ primary_keys = ["a_primary_key" ]
1202+ mock_model .database = "a_dev_db"
1203+ mock_model .schema_ = "a_schema"
1204+ mock_model .config .schema_ = None
1205+ mock_model .config .database = "custom_database"
1206+ mock_model .alias = "a_model_name"
1207+ mock_tdatadiffmodelconfig = Mock ()
1208+ mock_tdatadiffmodelconfig .where_filter = "where"
1209+ mock_tdatadiffmodelconfig .include_columns = ["include" ]
1210+ mock_tdatadiffmodelconfig .exclude_columns = ["exclude" ]
1211+ mock_dbt_parser = Mock ()
1212+ mock_dbt_parser .get_datadiff_model_config .return_value = mock_tdatadiffmodelconfig
1213+ mock_dbt_parser .connection = {}
1214+ mock_dbt_parser .threads = 0
1215+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
1216+ mock_dbt_parser .requires_upper = False
1217+ mock_model .meta = None
1218+
1219+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
1220+
1221+ assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
1222+ assert diff_vars .prod_path == [mock_model .config .database , mock_model .schema_ , mock_model .alias ]
1223+ assert diff_vars .primary_keys == primary_keys
1224+ assert diff_vars .connection == mock_dbt_parser .connection
1225+ assert diff_vars .threads == mock_dbt_parser .threads
1226+ self .assertEqual (diff_vars .where_filter , mock_tdatadiffmodelconfig .where_filter )
1227+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
0 commit comments