1 package com.mapd.calcite.parser;
6 import org.apache.calcite.linq4j.Ord;
7 import org.apache.calcite.rel.type.RelDataType;
8 import org.apache.calcite.rel.type.RelDataTypeField;
9 import org.apache.calcite.sql.SqlBasicCall;
10 import org.apache.calcite.sql.SqlCall;
11 import org.apache.calcite.sql.SqlCallBinding;
12 import org.apache.calcite.sql.SqlFunctionCategory;
13 import org.apache.calcite.sql.SqlNode;
14 import org.apache.calcite.sql.SqlOperandCountRange;
16 import org.apache.calcite.sql.SqlSyntax;
17 import org.apache.calcite.sql.type.SqlOperandCountRanges;
18 import org.apache.calcite.sql.type.SqlOperandTypeChecker;
19 import org.apache.calcite.sql.type.SqlTypeName;
20 import org.apache.calcite.sql.validate.SqlNameMatchers;
22 import java.util.ArrayList;
23 import java.util.HashMap;
24 import java.util.HashSet;
25 import java.util.Iterator;
26 import java.util.List;
28 import java.util.stream.Collectors;
47 SqlCallBinding callBinding,
50 SqlCall permutedCall = callBinding.permutedCall();
51 SqlNode permutedOperand = permutedCall.operand(iFormalOperand);
57 type = callBinding.getValidator().deriveType(
58 callBinding.getScope(), permutedOperand);
59 }
catch (Exception e) {
62 SqlTypeName
typeName = type.getSqlTypeName();
64 if (typeName == SqlTypeName.CURSOR) {
65 SqlCall cursorCall = (SqlCall) permutedOperand;
66 RelDataType cursorType = callBinding.getValidator().deriveType(
67 callBinding.getScope(), cursorCall.operand(0));
70 return tf.getArgTypes().
get(iFormalOperand).getTypeNames().
contains(typeName);
75 Set<ExtTableFunction> candidateOverloads =
new HashSet<ExtTableFunction>(
80 candidateOverloads.removeIf(
81 tf -> tf.getArgTypes().size() != callBinding.getOperandCount());
83 SqlNode[] operandArray =
new SqlNode[callBinding.getCall().getOperandList().size()];
84 for (Ord<SqlNode> arg : Ord.zip(callBinding.getCall().getOperandList())) {
85 operandArray[arg.i] = arg.e;
91 HashMap<ExtTableFunction, SqlCallBinding> candidateBindings =
92 new HashMap<>(candidateOverloads.size());
94 SqlBasicCall newCall =
new SqlBasicCall(
95 tf, operandArray, callBinding.getCall().getParserPosition());
96 SqlCallBinding candidateBinding =
new SqlCallBinding(
97 callBinding.getValidator(), callBinding.getScope(), newCall);
98 candidateBindings.put(tf, candidateBinding);
101 for (
int i = 0; i < operandArray.length; i++) {
103 candidateOverloads.removeIf(tf
105 tf, candidateBindings.get(tf), operandArray[idx], idx));
109 if (candidateOverloads.size() == 0) {
110 if (throwOnFailure) {
111 throw(callBinding.newValidationSignatureError());
118 if (!candidateOverloads.isEmpty()
119 && !candidateOverloads.contains(callBinding.getOperator())) {
121 ((SqlBasicCall) callBinding.getCall()).setOperator(optimal);
129 String formalOperandName = tf.getExtendedParamNames().
get(iFormalOperand);
130 List<ExtensionFunction.ExtArgumentType> formalFieldTypes =
131 tf.getCursorFieldTypes().
get(formalOperandName);
132 List<RelDataTypeField> actualFieldList = actualOperand.getFieldList();
136 if (formalFieldTypes.size() == 0) {
138 "Warning: UDTF has no CURSOR field subtype data. Proceeding assuming CURSOR typechecks.");
144 while (iActual < actualFieldList.size() && iFormal < formalFieldTypes.size()) {
145 ExtensionFunction.ExtArgumentType extType = formalFieldTypes.get(iFormal);
146 SqlTypeName formalType = ExtensionFunction.toSqlTypeName(extType);
147 SqlTypeName actualType = actualFieldList.get(iActual).getValue().getSqlTypeName();
149 if (formalType == SqlTypeName.COLUMN_LIST) {
150 ExtensionFunction.ExtArgumentType colListSubtype =
151 ExtensionFunction.getValueType(extType);
152 SqlTypeName colListType = ExtensionFunction.toSqlTypeName(colListSubtype);
154 if (actualType != colListType) {
159 int numFormalArgumentsLeft = (formalFieldTypes.size() - 1) - iFormal;
160 while (iActual + colListSize
161 < (actualFieldList.size() - numFormalArgumentsLeft)) {
163 actualFieldList.get(iActual + colListSize).getValue().getSqlTypeName();
164 if (actualType != colListType) {
169 iActual += colListSize - 1;
170 }
else if (formalType != actualType) {
177 if (iActual < actualFieldList.size()) {
185 List<SqlOperator> overloads =
new ArrayList<>();
186 opTable.lookupOperatorOverloads(op.getNameAsId(),
187 SqlFunctionCategory.USER_DEFINED_TABLE_FUNCTION,
190 SqlNameMatchers.liberal());
192 return overloads.stream()
194 .map(p -> (ExtTableFunction) p)
195 .collect(Collectors.toList());
199 return SqlOperandCountRanges.any();
204 return String.join(System.lineSeparator() +
"\t",
206 .map(tf -> tf.getExtendedSignature())
207 .collect(Collectors.toList()));
211 return Consistency.NONE;
bool contains(const T &container, const U &element)
final HeavyDBSqlOperatorTable opTable
ExtTableFunctionTypeChecker(HeavyDBSqlOperatorTable opTable)
Consistency getConsistency()
SqlOperandCountRange getOperandCountRange()
List< ExtTableFunction > getOperatorOverloads(SqlOperator op)
boolean doesCursorOperandTypeMatch(ExtTableFunction tf, int iFormalOperand, RelDataType actualOperand)
boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
std::string typeName(const T *v)
boolean isOptional(int argIndex)
boolean doesOperandTypeMatch(ExtTableFunction tf, SqlCallBinding callBinding, SqlNode node, int iFormalOperand)
String getAllowedSignatures(SqlOperator op, String opName)