/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Objects;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;

/**
 * Planner rule that collapses an {@link Aggregate} sitting directly on another,
 * simple {@link Aggregate} into a single {@link Aggregate} over the lower
 * aggregate's input.
 *
 * <p>The merge is attempted only when:
 * <ul>
 *   <li>every group key of the top aggregate is one of the bottom aggregate's
 *       group-key columns; and</li>
 *   <li>every call of the top aggregate can be combined, via
 *       {@link SqlSplittableAggFunction#merge}, with the bottom-aggregate call
 *       that it reads.</li>
 * </ul>
 * If any condition fails, the rule produces nothing.
 */
public class AggregateMergeRule extends RelRule<Config> implements TransformationRule {

    /** Creates the rule from its configuration. */
    protected AggregateMergeRule(Config config) {
        super(config);
    }

    /**
     * Creates the rule from an explicit operand and relational-builder factory.
     *
     * @deprecated obtain the rule from a {@link Config} via {@link Config#toRule()} instead
     */
    @Deprecated
    public AggregateMergeRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory) {
        this(legacyConfig(operand, relBuilderFactory));
    }

    /**
     * Builds the configuration used by the deprecated constructor.
     *
     * <p>It starts from {@link Config#DEFAULT}, applies the given builder factory,
     * and matches exactly the given operand.
     */
    private static Config legacyConfig(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory) {
        return Config.DEFAULT
            .withRelBuilderFactory(relBuilderFactory)
            .withOperandSupplier(b -> b.exactly(operand))
            .as(Config.class);
    }

    /**
     * Tells whether an aggregate call is simple enough to take part in a merge.
     *
     * @param aggCall the call to inspect
     * @return {@code true} only if the call meets all of the following:
     *     it is not DISTINCT, has no FILTER, is not approximate, has at most
     *     one argument, and its aggregation unwraps to a
     *     {@link SqlSplittableAggFunction}
     */
    private static boolean isAggregateSupported(AggregateCall aggCall) {
        final boolean plainCall = !aggCall.isDistinct()
            && !aggCall.hasFilter()
            && !aggCall.isApproximate()
            && aggCall.getArgList().size() <= 1;
        // The unwrap is only consulted for calls that passed the cheap checks.
        return plainCall && aggCall.getAggregation().unwrap(SqlSplittableAggFunction.class) != null;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate topAgg = (Aggregate) call.rel(0);
        final Aggregate bottomAgg = (Aggregate) call.rel(1);

        // The top aggregate cannot group by more columns than the bottom one exposes as keys.
        if (topAgg.getGroupCount() > bottomAgg.getGroupCount()) {
            return;
        }

        // Output column i of the bottom aggregate is its i-th group key. Record
        // which input column of the bottom aggregate each of those keys comes from.
        final ImmutableBitSet bottomGroupSet = bottomAgg.getGroupSet();
        final HashMap<Integer, Integer> topToBottomKey = new HashMap<>();
        int outputPosition = 0;
        for (int bottomKey : bottomGroupSet) {
            topToBottomKey.put(outputPosition++, bottomKey);
        }

        // Every top group key must refer to a bottom group-key column, not to an
        // aggregate-call column of the bottom aggregate.
        for (int topKey : topAgg.getGroupSet()) {
            if (!topToBottomKey.containsKey(topKey)) {
                return;
            }
        }

        // Re-express the top group keys in terms of the bottom aggregate's input.
        final ImmutableBitSet topGroupSet = topAgg.getGroupSet().permute(topToBottomKey);
        if (!bottomGroupSet.contains(topGroupSet)) {
            return;
        }

        final boolean hasEmptyGroup = topAgg.getGroupSets().stream().anyMatch(ImmutableBitSet::isEmpty);
        final int bottomKeyCount = bottomGroupSet.cardinality();
        final ArrayList<AggregateCall> mergedCalls = new ArrayList<>();
        for (AggregateCall topCall : topAgg.getAggCallList()) {
            final AggregateCall merged = mergeCall(topCall, bottomAgg, bottomKeyCount, hasEmptyGroup);
            if (merged == null) {
                // A single call that cannot be merged blocks the whole rewrite.
                return;
            }
            mergedCalls.add(merged);
        }

        // Grouping sets only need remapping when the top aggregate is not a plain GROUP BY.
        final ImmutableList<ImmutableBitSet> remappedGroupSets;
        if (topAgg.getGroupType() == Aggregate.Group.SIMPLE) {
            remappedGroupSets = null;
        } else {
            remappedGroupSets = ImmutableBitSet.ORDERING.immutableSortedCopy(
                ImmutableBitSet.permute(topAgg.getGroupSets(), topToBottomKey));
        }

        final Aggregate mergedAgg = topAgg.copy(
            topAgg.getTraitSet(), bottomAgg.getInput(), topGroupSet, remappedGroupSets, mergedCalls);
        call.transformTo(mergedAgg);
    }

    /**
     * Combines one call of the top aggregate with the bottom-aggregate call it reads.
     *
     * @param topCall        a call of the top aggregate
     * @param bottomAgg      the bottom aggregate
     * @param bottomKeyCount number of group keys of the bottom aggregate; output
     *                       columns at or beyond this index are its aggregate calls
     * @param hasEmptyGroup  whether any grouping set of the top aggregate is empty
     * @return the merged call, or {@code null} if the two calls cannot be merged
     */
    private AggregateCall mergeCall(AggregateCall topCall, Aggregate bottomAgg, int bottomKeyCount, boolean hasEmptyGroup) {
        if (!isAggregateSupported(topCall) || topCall.getArgList().isEmpty()) {
            return null;
        }

        // The single argument must point at one of the bottom aggregate's call columns.
        final int bottomIndex = topCall.getArgList().get(0) - bottomKeyCount;
        if (bottomIndex < 0 || bottomIndex >= bottomAgg.getAggCallList().size()) {
            return null;
        }

        final AggregateCall bottomCall = bottomAgg.getAggCallList().get(bottomIndex);
        if (!isAggregateSupported(bottomCall)) {
            return null;
        }

        // NOTE(review): presumably unsafe because an empty grouping set emits a row
        // even for empty input. There, COUNT over the raw input yields 0, while a
        // non-SUM0 top function over zero bottom rows yields NULL. SUM0 also yields 0,
        // hence its exemption.
        if (hasEmptyGroup
            && bottomCall.getAggregation() == SqlStdOperatorTable.COUNT
            && topCall.getAggregation().getKind() != SqlKind.SUM0) {
            return null;
        }

        // Non-null here, because isAggregateSupported(bottomCall) has already held.
        final SqlSplittableAggFunction splitter =
            Objects.requireNonNull(bottomCall.getAggregation().unwrap(SqlSplittableAggFunction.class));
        // The splitter may return null to signal that it cannot merge this pair.
        return splitter.merge(topCall, bottomCall);
    }

    /** Configuration for {@link AggregateMergeRule}. */
    public interface Config extends RelRule.Config {
        /** Matches an {@link Aggregate} whose single input is a simple {@link Aggregate}. */
        Config DEFAULT = EMPTY
            .withOperandSupplier(top -> top.operand(Aggregate.class)
                .oneInput(bottom -> bottom.operand(Aggregate.class)
                    .predicate(Aggregate::isSimple)
                    .anyInputs()))
            .as(Config.class);

        @Override
        default AggregateMergeRule toRule() {
            return new AggregateMergeRule(this);
        }
    }
}

