diff --git a/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCase.java b/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCase.java
index f348183..fe34af6 100644
--- a/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCase.java
+++ b/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCase.java
@@ -22,7 +22,9 @@ import org.killbill.billing.catalog.DefaultProduct;
import org.killbill.billing.catalog.StandaloneCatalog;
import org.killbill.billing.catalog.api.BillingPeriod;
import org.killbill.billing.catalog.api.CatalogApiException;
+import org.killbill.billing.catalog.api.Plan;
import org.killbill.billing.catalog.api.PlanSpecifier;
+import org.killbill.billing.catalog.api.Product;
import org.killbill.billing.catalog.api.ProductCategory;
import org.killbill.billing.catalog.api.StaticCatalog;
import org.killbill.xmlloader.ValidatingConfig;
@@ -49,6 +51,7 @@ public abstract class DefaultCase<T> extends ValidatingConfig<StandaloneCatalog>
protected boolean satisfiesCase(final PlanSpecifier planPhase, final StaticCatalog c) throws CatalogApiException {
return (getProduct() == null || getProduct().equals(c.findCurrentProduct(planPhase.getProductName()))) &&
+ (getProductCategory() == null || getProductCategory().equals(c.findCurrentProduct(planPhase.getProductName()).getCategory())) &&
(getBillingPeriod() == null || getBillingPeriod().equals(planPhase.getBillingPeriod())) &&
(getPriceList() == null || getPriceList().equals(c.findCurrentPricelist(planPhase.getPriceListName())));
}
diff --git a/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCaseChange.java b/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCaseChange.java
index 9955ed8..525a5d7 100644
--- a/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCaseChange.java
+++ b/catalog/src/main/java/org/killbill/billing/catalog/rules/DefaultCaseChange.java
@@ -27,6 +27,7 @@ import org.killbill.billing.catalog.StandaloneCatalog;
import org.killbill.billing.catalog.api.BillingPeriod;
import org.killbill.billing.catalog.api.CatalogApiException;
import org.killbill.billing.catalog.api.PhaseType;
+import org.killbill.billing.catalog.api.Plan;
import org.killbill.billing.catalog.api.PlanPhaseSpecifier;
import org.killbill.billing.catalog.api.PlanSpecifier;
import org.killbill.billing.catalog.api.PriceList;
@@ -73,18 +74,19 @@ public abstract class DefaultCaseChange<T> extends ValidatingConfig<StandaloneCa
protected abstract T getResult();
-
-
public T getResult(final PlanPhaseSpecifier from,
final PlanSpecifier to, final StaticCatalog catalog) throws CatalogApiException {
+
if (
(phaseType == null || from.getPhaseType() == phaseType) &&
- (fromProduct == null || fromProduct.equals(catalog.findCurrentProduct(from.getProductName()))) &&
- (fromBillingPeriod == null || fromBillingPeriod.equals(from.getBillingPeriod())) &&
- (toProduct == null || toProduct.equals(catalog.findCurrentProduct(to.getProductName()))) &&
- (toBillingPeriod == null || toBillingPeriod.equals(to.getBillingPeriod())) &&
- (fromPriceList == null || fromPriceList.equals(catalog.findCurrentPricelist(from.getPriceListName()))) &&
- (toPriceList == null || toPriceList.equals(catalog.findCurrentPricelist(to.getPriceListName())))
+ (fromProduct == null || fromProduct.equals(catalog.findCurrentProduct(from.getProductName()))) &&
+ (fromProductCategory == null || fromProductCategory.equals(catalog.findCurrentProduct(from.getProductName()).getCategory())) &&
+ (fromBillingPeriod == null || fromBillingPeriod.equals(from.getBillingPeriod())) &&
+ (this.toProduct == null || this.toProduct.equals(catalog.findCurrentProduct(to.getProductName()))) &&
+ (toProductCategory == null || toProductCategory.equals(catalog.findCurrentProduct(to.getProductName()).getCategory())) &&
+ (toBillingPeriod == null || toBillingPeriod.equals(to.getBillingPeriod())) &&
+ (fromPriceList == null || fromPriceList.equals(catalog.findCurrentPricelist(from.getPriceListName()))) &&
+ (toPriceList == null || toPriceList.equals(catalog.findCurrentPricelist(to.getPriceListName())))
) {
return getResult();
}