OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
connection.py
Go to the documentation of this file.
1 """
2 
Connect to a HeavyDB database.
4 """
5 from collections import namedtuple
6 from sqlalchemy.engine.url import make_url
7 from thrift.protocol import TBinaryProtocol, TJSONProtocol
8 from thrift.transport import TSocket, TSSLSocket, THttpClient, TTransport
9 from thrift.transport.TSocket import TTransportException
10 from heavydb.thrift.Heavy import Client
11 from heavydb.thrift.ttypes import TDBException
12 
13 from .cursor import Cursor
14 from .exceptions import _translate_exception, OperationalError
15 
16 from ._samlutils import get_saml_response
17 
18 from packaging.version import Version
19 
20 
# Immutable record of the settings that ``_parse_uri`` extracts from a
# connection URI. The field order is significant: ``Connection.__init__``
# unpacks it positionally.
ConnectionInfo = namedtuple(
    "ConnectionInfo",
    "user password host port dbname protocol bin_cert_validate bin_ca_certs",
)
34 
35 
def connect(
    uri=None,
    user=None,
    password=None,
    host=None,
    port=6274,
    dbname=None,
    protocol='binary',
    sessionid=None,
    bin_cert_validate=None,
    bin_ca_certs=None,
    idpurl=None,
    idpformusernamefield='username',
    idpformpasswordfield='password',
    idpsslverify=True,
):
    """
    Create a new Connection.

    Parameters
    ----------
    uri: str
        Connection URI such as
        ``heavydb://user:password@host:port/dbname?protocol=binary``.
        Cannot be combined with the other connection arguments.
    user: str
    password: str
    host: str
        Required unless given through ``uri``.
    port: int
        Defaults to 6274.
    dbname: str
    protocol: {'binary', 'http', 'https'}
    sessionid: str
        An existing session to reuse. Cannot be combined with ``user``,
        ``password``, ``dbname``, ``uri`` or ``idpurl``. A connection built
        this way does not disconnect the session when it is closed.
    bin_cert_validate: bool, optional, binary encrypted connection only
        Whether to validate the server's TLS certificate. Setting this
        (or ``bin_ca_certs``) switches the binary protocol to TLS.
    bin_ca_certs: str, optional, binary encrypted connection only
        Path to the CA certificate file
    idpurl : str
        EXPERIMENTAL Enable SAML authentication by providing
        the logon page of the SAML Identity Provider.
    idpformusernamefield: str
        The HTML form ID for the username, defaults to 'username'.
    idpformpasswordfield: str
        The HTML form ID for the password, defaults to 'password'.
    idpsslverify: bool
        Enable / disable certificate checking against the Identity
        Provider, defaults to True.

    Returns
    -------
    conn: Connection

    Examples
    --------
    You can pass a string ``uri``, all the individual components, or an
    existing sessionid (in which case user, password and dbname must be
    omitted).

    >>> connect('heavydb://admin:HyperInteractive@localhost:6274/heavyai?'
    ...         'protocol=binary')
    Connection(heavydb://admin:***@localhost:6274/heavyai?protocol=binary)

    >>> connect(user='admin', password='HyperInteractive', host='localhost',
    ...         port=6274, dbname='heavyai')

    >>> connect(user='admin', password='HyperInteractive', host='localhost',
    ...         port=443, idpurl='https://sso.localhost/logon',
    ...         protocol='https')

    >>> connect(sessionid='XihlkjhdasfsadSDoasdllMweieisdpo', host='localhost',
    ...         port=6273, protocol='http')

    """
    return Connection(
        uri=uri,
        user=user,
        password=password,
        host=host,
        port=port,
        dbname=dbname,
        protocol=protocol,
        sessionid=sessionid,
        bin_cert_validate=bin_cert_validate,
        bin_ca_certs=bin_ca_certs,
        idpurl=idpurl,
        idpformusernamefield=idpformusernamefield,
        idpformpasswordfield=idpformpasswordfield,
        idpsslverify=idpsslverify,
    )
119 
120 
def _parse_uri(uri):
    """
    Parse connection string

    Parameters
    ----------
    uri: str
        a URI containing connection information

    Returns
    -------
    info: ConnectionInfo

    Raises
    ------
    ValueError
        If ``bin_cert_validate`` is present but is not a recognised boolean
        spelling (true/false, yes/no, on/off, 1/0; case-insensitive).

    Notes
    ------
    The URI may include information on

    - user
    - password
    - host
    - port (defaults to 6274, as in ``connect``, when omitted)
    - dbname
    - protocol (query parameter, defaults to ``'binary'``)
    - bin_cert_validate (query parameter, converted to ``bool``)
    - bin_ca_certs (query parameter)
    """
    url = make_url(uri)
    query = url.query

    bin_cert_validate = query.get('bin_cert_validate', None)
    if isinstance(bin_cert_validate, str):
        # Query-string values are always text. Without this conversion
        # "bin_cert_validate=false" would reach TSSLSocket's ``validate`` as
        # the truthy string "false", leaving certificate validation enabled.
        flag = bin_cert_validate.strip().lower()
        if flag in ('1', 'true', 'yes', 'on'):
            bin_cert_validate = True
        elif flag in ('0', 'false', 'no', 'off'):
            bin_cert_validate = False
        else:
            raise ValueError(
                "bin_cert_validate must be a boolean (true/false), "
                "got {!r}".format(bin_cert_validate)
            )

    # make_url leaves the port as None when the URI omits it; fall back to
    # the same default port that ``connect`` uses.
    port = url.port if url.port is not None else 6274

    return ConnectionInfo(
        url.username,
        url.password,
        url.host,
        port,
        url.database,
        query.get('protocol', 'binary'),
        bin_cert_validate,
        query.get('bin_ca_certs', None),
    )
167 
168 
170  """Connect to your HeavyDB database."""
171 
172  def __init__(
173  self,
174  uri=None,
175  user=None,
176  password=None,
177  host=None,
178  port=6274,
179  dbname=None,
180  protocol='binary',
181  sessionid=None,
182  bin_cert_validate=None,
183  bin_ca_certs=None,
184  idpurl=None,
185  idpformusernamefield='username',
186  idpformpasswordfield='password',
187  idpsslverify=True,
188  ):
189 
190  self.sessionid = None
191  self._closed = 0
192  if sessionid is not None:
193  if any([user, password, uri, dbname, idpurl]):
194  raise TypeError(
195  "Cannot specify sessionid with user, password,"
196  " dbname, uri, or idpurl"
197  )
198  if uri is not None:
199  if not all(
200  [
201  user is None,
202  password is None,
203  host is None,
204  port == 6274,
205  dbname is None,
206  protocol == 'binary',
207  bin_cert_validate is None,
208  bin_ca_certs is None,
209  idpurl is None,
210  ]
211  ):
212  raise TypeError("Cannot specify both URI and other arguments")
213  (
214  user,
215  password,
216  host,
217  port,
218  dbname,
219  protocol,
220  bin_cert_validate,
221  bin_ca_certs,
222  ) = _parse_uri(uri)
223  if host is None:
224  raise TypeError("`host` parameter is required.")
225  if protocol != 'binary' and not all(
226  [bin_cert_validate is None, bin_ca_certs is None]
227  ):
228  raise TypeError(
229  "Cannot specify bin_cert_validate or bin_ca_certs,"
230  " without binary protocol"
231  )
232  if protocol in ("http", "https"):
233  if not host.startswith(protocol):
234  # the THttpClient expects http[s]://localhost
235  host = '{0}://{1}'.format(protocol, host)
236  transport = THttpClient.THttpClient("{}:{}".format(host, port))
237  proto = TJSONProtocol.TJSONProtocol(transport)
238  socket = None
239  elif protocol == "binary":
240  if any([bin_cert_validate is not None, bin_ca_certs]):
241  socket = TSSLSocket.TSSLSocket(
242  host,
243  port,
244  validate=(bin_cert_validate),
245  ca_certs=bin_ca_certs,
246  )
247  else:
248  socket = TSocket.TSocket(host, port)
249  transport = TTransport.TBufferedTransport(socket)
250  proto = TBinaryProtocol.TBinaryProtocolAccelerated(transport)
251  else:
252  raise ValueError(
253  "`protocol` should be one of"
254  " ['http', 'https', 'binary'],"
255  " got {} instead".format(protocol),
256  )
257  self._user = user
258  self._password = password
259  self._host = host
260  self._port = port
261  self._dbname = dbname
262  self._transport = transport
263  self._protocol = protocol
264  self._socket = socket
265  self._tdf = None
266  self._rbc = None
267  try:
268  self._transport.open()
269  except TTransportException as e:
270  if e.NOT_OPEN:
271  err = OperationalError("Could not connect to database")
272  raise err from e
273  else:
274  raise
275  self._client = Client(proto)
276  try:
277  # If a sessionid was passed, we should validate it
278  if sessionid:
279  self._session = sessionid
280  self._client.get_tables(self._session)
281  self.sessionid = sessionid
282  else:
283  if idpurl:
284  self._user = ''
286  username=user,
287  password=password,
288  idpurl=idpurl,
289  userformfield=idpformusernamefield,
290  passwordformfield=idpformpasswordfield,
291  sslverify=idpsslverify,
292  )
293  self._dbname = ''
294  self._idpsslverify = idpsslverify
295  user = self._user
296  password = self._password
297  dbname = self._dbname
298 
299  self._session = self._client.connect(user, password, dbname)
300  except TDBException as e:
301  raise _translate_exception(e) from e
302  except TTransportException:
303  raise ValueError(
304  f"Connection failed with port {port} and "
305  f"protocol '{protocol}'. Try port 6274 for "
306  "protocol == binary or 6273, 6278 or 443 for "
307  "http[s]"
308  )
309 
310  # if HeavyDB version <4.6, raise RuntimeError, as data import can be
311  # incorrect for columnar date loads
312  # Caused by https://github.com/omnisci/pymapd/pull/188
313  semver = self._client.get_version()
314  if Version(semver.split("-")[0]) < Version("4.6"):
315  raise RuntimeError(
316  f"Version {semver} of HeavyDB detected. "
317  "Please use pymapd <0.11. See release notes "
318  "for more details."
319  )
320 
321  def __repr__(self):
322  tpl = (
323  'Connection(heavydb://{user}:***@{host}:{port}/{dbname}?'
324  'protocol={protocol})'
325  )
326  return tpl.format(
327  user=self._user,
328  host=self._host,
329  port=self._port,
330  dbname=self._dbname,
331  protocol=self._protocol,
332  )
333 
334  def __del__(self):
335  self.close()
336 
337  def __enter__(self):
338  return self
339 
340  def __exit__(self, exc_type, exc_val, exc_tb):
341  self.close()
342 
343  @property
344  def closed(self):
345  return self._closed
346 
347  def close(self):
348  """Disconnect from the database unless created with sessionid"""
349  if not self.sessionid and not self._closed:
350  try:
351  self._client.disconnect(self._session)
352  except (TDBException, AttributeError, TypeError):
353  pass
354  self._closed = 1
355  self._rbc = None
356 
357  def commit(self):
358  """This is a noop, as HeavyDB does not provide transactions.
359 
360  Implemented to comply with the DBI specification.
361  """
362  return None
363 
364  def execute(self, operation, parameters=None):
365  """Execute a SQL statement
366 
367  Parameters
368  ----------
369  operation: str
370  A SQL statement to exucute
371 
372  Returns
373  -------
374  c: Cursor
375  """
376  c = Cursor(self)
377  return c.execute(operation, parameters=parameters)
378 
379  def cursor(self):
380  """Create a new :class:`Cursor` object attached to this connection."""
381  return Cursor(self)
382 
383  def __call__(self, *args, **kwargs):
384  """Runtime UDF decorator.
385 
386  The connection object can be applied to a Python function as
387  decorator that will add the function to bending registration
388  list.
389  """
390  try:
391  from rbc.heavydb import RemoteHeavyDB
392  except ImportError:
393  raise ImportError("The 'rbc' package is required for `__call__`")
394  if self._rbc is None:
395  self._rbc = RemoteHeavyDB(
396  user=self._user,
397  password=self._password,
398  host=self._host,
399  port=self._port,
400  dbname=self._dbname,
401  )
402  self._rbc._session_id = self.sessionid
403  return self._rbc(*args, **kwargs)
404 
406  """Register any bending Runtime UDF functions in HeavyDB server.
407 
408  If no Runtime UDFs have been defined, the call to this method
409  is noop.
410  """
411  if self._rbc is not None:
412  self._rbc.register()