OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
InjectFilterRule.java
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package org.apache.calcite.rel.rules;
18 
19 import org.apache.calcite.plan.RelOptRuleCall;
20 import org.apache.calcite.plan.RelOptTable;
21 import org.apache.calcite.plan.RelRule;
22 import org.apache.calcite.rel.RelNode;
23 import org.apache.calcite.rel.logical.LogicalTableScan;
24 import org.apache.calcite.rel.type.RelDataTypeField;
25 import org.apache.calcite.rex.RexBuilder;
26 import org.apache.calcite.rex.RexNode;
27 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
28 import org.apache.calcite.sql.type.SqlTypeName;
29 import org.apache.calcite.tools.RelBuilder;
30 import org.slf4j.Logger;
31 import org.slf4j.LoggerFactory;
32 
33 import java.util.ArrayList;
34 import java.util.HashSet;
35 import java.util.List;
36 import java.util.Set;
37 
38 public class InjectFilterRule extends RelRule<InjectFilterRule.Config> {
39  // goal: customer entitlements first swipe
40 
41  public static Set<String> visitedMemo = new HashSet<>();
42  final static Logger HEAVYDBLOGGER = LoggerFactory.getLogger(InjectFilterRule.class);
43  final List<Restriction> restrictions;
44 
45  public InjectFilterRule(Config config, List<Restriction> restrictions) {
46  super(config);
47  this.restrictions = restrictions;
48  clearMemo();
49  }
50 
51  void clearMemo() {
52  visitedMemo.clear();
53  }
54 
55  @Override
56  public void onMatch(RelOptRuleCall call) {
57  LogicalTableScan childScanNode = call.rel(0);
58  String scanNodeString = childScanNode.toString();
59  if (visitedMemo.contains(scanNodeString)) {
60  return;
61  } else {
62  visitedMemo.add(scanNodeString);
63  }
64  RelOptTable table = childScanNode.getTable();
65  List<String> qname = table.getQualifiedName();
66 
67  String query_database = null;
68  String query_table = null;
69  if (qname.size() == 2) {
70  query_database = qname.get(0);
71  query_table = qname.get(1);
72  }
73  if (query_database == null || query_database.isEmpty() || query_table == null
74  || query_table.isEmpty()) {
75  throw new RuntimeException(
76  "Restrictions: Expected qualified name as [database, table] but got: "
77  + qname);
78  }
79 
80  ArrayList<RexNode> orList = new ArrayList<RexNode>();
81  RelBuilder builder = call.builder();
82  RexBuilder rBuilder = builder.getRexBuilder();
83  builder = builder.push(childScanNode);
84  boolean found = false;
85  for (Restriction restriction : restrictions) {
86  // Match the database name.
87  String rest_database = restriction.getRestrictionDatabase();
88  if (rest_database != null && !rest_database.isEmpty()
89  && !rest_database.equals(query_database)) {
90  // TODO(sy): Maybe remove the isEmpty() wildcarding in HEAVY.AI 6.0.
91  HEAVYDBLOGGER.debug("RLS row-level security restriction for database "
92  + rest_database + " ignored because this query is on database "
93  + query_database);
94  continue;
95  }
96 
97  // Match the table name.
98  String rest_table = restriction.getRestrictionTable();
99  if (rest_table != null && !rest_table.isEmpty()
100  && !rest_table.equals(query_table)) {
101  // TODO(sy): Maybe remove the isEmpty() wildcarding in HEAVY.AI 6.0.
102  HEAVYDBLOGGER.debug("RLS row-level security restriction for table " + rest_table
103  + " ignored because this query is on table " + query_table);
104  continue;
105  }
106 
107  // Match the column name.
108  RelDataTypeField field = table.getRowType().getField(
109  restriction.getRestrictionColumn(), false, false);
110  if (field == null) {
111  HEAVYDBLOGGER.debug("RLS row-level security restriction for column "
112  + restriction.getRestrictionColumn()
113  + " ignored because column not present in query table " + query_table);
114  continue;
115  }
116 
117  // Generate the RLS row-level security filter for one Restriction.
118  found = true;
119  HEAVYDBLOGGER.debug(
120  "Scan is " + childScanNode.toString() + " TABLE is " + table.toString());
121  HEAVYDBLOGGER.debug("Column " + restriction.getRestrictionColumn()
122  + " exists in table " + table.getQualifiedName());
123 
124  for (String val : restriction.getRestrictionValues()) {
125  HEAVYDBLOGGER.debug("Column is " + restriction.getRestrictionColumn()
126  + " literal is '" + val + "'");
127  RexNode lit;
128  if (SqlTypeName.NUMERIC_TYPES.indexOf(field.getType().getSqlTypeName()) == -1) {
129  if (val.length() < 2 || val.charAt(0) != '\''
130  || val.charAt(val.length() - 1) != '\'') {
131  throw new RuntimeException(
132  "Restrictions: Expected a CREATE POLICY VALUES string with single quotes.");
133  }
134  lit = rBuilder.makeLiteral(
135  val.substring(1, val.length() - 1), field.getType(), false);
136  } else {
137  lit = rBuilder.makeLiteral(Integer.parseInt(val), field.getType(), false);
138  }
139  RexNode rn = builder.call(SqlStdOperatorTable.EQUALS,
140  builder.field(restriction.getRestrictionColumn()),
141  lit);
142  orList.add(rn);
143  }
144  }
145 
146  if (found) {
147  RexNode relOr = builder.call(SqlStdOperatorTable.OR, orList);
148  final RelNode newNode = builder.filter(relOr).build();
149  call.transformTo(newNode);
150  }
151  };
152 
154  public interface Config extends RelRule.Config {
156  EMPTY.withOperandSupplier(b0 -> b0.operand(LogicalTableScan.class).noInputs())
157  .as(Config.class);
158 
159  @Override
161  return new InjectFilterRule(this, null);
162  }
163 
164  default InjectFilterRule toRule(List<Restriction> rests) {
165  return new InjectFilterRule(this, rests);
166  }
167  }
168 }
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:33
default InjectFilterRule toRule(List< Restriction > rests)
InjectFilterRule(Config config, List< Restriction > restrictions)