Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

324

325

326

327

328

329

330

331

332

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367

368

369

370

371

372

373

374

375

376

377

378

379

380

381

382

383

384

385

386

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

405

406

407

408

409

410

411

412

413

414

415

416

417

418

419

420

421

422

423

424

425

426

427

428

429

430

431

432

433

434

435

436

437

438

439

440

441

442

443

444

445

446

447

448

449

450

451

452

453

454

455

456

457

458

459

460

461

462

463

464

465

466

467

468

469

470

471

472

473

474

475

476

477

478

479

480

481

482

############################################################################## 

# 

# Copyright (c) 2008-2012 Jens Vagelpohl and Contributors. All Rights Reserved. 

# 

# This software is subject to the provisions of the Zope Public License, 

# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 

# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED 

# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 

# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS 

# FOR A PARTICULAR PURPOSE. 

# 

############################################################################## 

""" LDAPConnection: A class modeling a LDAP server connection 

 

Instances of this class offer a simplified API to do searches, insertions, 

deletions or modifications. 

""" 

 

import ldap 

from ldap.dn import str2dn 

from ldap.ldapobject import ReconnectLDAPObject 

import ldapurl 

import logging 

from random import random 

import six 

 

from zope.interface import implementer 

 

from dataflake.cache.simple import LockingSimpleCache 

from dataflake.ldapconnection.interfaces import ILDAPConnection 

from dataflake.ldapconnection.utils import BINARY_ATTRIBUTES 

from dataflake.ldapconnection.utils import dn2str 

from dataflake.ldapconnection.utils import escape_dn 

 

 

default_logger = logging.getLogger('dataflake.ldapconnection') 

connection_cache = LockingSimpleCache() 

_marker = () 

 

 

@implementer(ILDAPConnection) 

class LDAPConnection(object): 

""" LDAPConnection object 

 

See `interfaces.py` for interface documentation. 

""" 

 

def __init__(self, host='', port=389, protocol='ldap', 

c_factory=ReconnectLDAPObject, rdn_attr='', 

bind_dn=b'', bind_pwd='', read_only=False, conn_timeout=-1, 

op_timeout=-1, logger=None, ldap_encoding='UTF-8', 

api_encoding='UTF-8'): 

""" LDAPConnection initialization 

 

Empty values for api_encoding or ldap_encoding mean "use unicode" 

""" 

if isinstance(bind_dn, six.text_type) and ldap_encoding: 

bind_dn = bind_dn.encode(ldap_encoding) 

self.bind_dn = bind_dn 

self.bind_pwd = bind_pwd 

self.read_only = read_only 

self.c_factory = c_factory 

self._logger = logger 

self.ldap_encoding = ldap_encoding 

self.api_encoding = api_encoding 

self.hash = id(self) + random() 

 

self.servers = {} 

if host: 

self.addServer(host, port, protocol, conn_timeout, op_timeout) 

 

def logger(self): 

""" Get the logger 

""" 

if self._logger is None: 

return default_logger 

 

return self._logger 

 

def addServer(self, host, port, protocol, conn_timeout=-1, op_timeout=-1): 

""" Add a server definition to the list of servers used 

""" 

start_tls = False 

if protocol == 'ldaptls': 

protocol = 'ldap' 

start_tls = True 

hp = '%s:%s' % (host, port) 

conn = ldapurl.LDAPUrl(urlscheme=protocol, hostport=hp) 

server_url = conn.initializeUrl() 

self.servers[server_url] = {'url': server_url, 

'conn_timeout': conn_timeout, 

'op_timeout': op_timeout, 

'start_tls': start_tls} 

 

def removeServer(self, host, port, protocol): 

""" Remove a server definition from the list of servers used 

""" 

hp = '%s:%s' % (host, port) 

conn = ldapurl.LDAPUrl(urlscheme=protocol, hostport=hp) 

server_url = conn.initializeUrl() 

if server_url in self.servers.keys(): 

del self.servers[server_url] 

 

def connect(self, bind_dn=None, bind_pwd=None): 

""" initialize an ldap server connection 

 

This method returns an instance of the underlying `pyldap` 

connection class. It does not need to be called explicitly, all 

other operations call it implicitly. 

""" 

if not self.servers: 

raise RuntimeError('No servers defined') 

 

if bind_dn is None: 

bind_dn = escape_dn(self._encode_incoming(self.bind_dn), 

self.ldap_encoding) 

bind_pwd = self._encode_incoming(self.bind_pwd) 

else: 

bind_dn = escape_dn(self._encode_incoming(bind_dn), 

self.ldap_encoding) 

bind_pwd = self._encode_incoming(bind_pwd) 

 

conn = self._getConnection() 

if conn is None: 

for server in self.servers.values(): 

try: 

conn = self._connect(server['url'], 

conn_timeout=server['conn_timeout'], 

op_timeout=server['op_timeout']) 

if server.get('start_tls', None): 

conn.start_tls_s() 

break 

except (ldap.SERVER_DOWN, ldap.TIMEOUT, ldap.LOCAL_ERROR) as e: 

conn = None 

exc = e 

continue 

 

if conn is None: 

msg = 'Failure connecting, last attempt: %s (%s)' % ( 

server['url'], str(exc) or 'no exception') 

self.logger().critical(msg, exc_info=1) 

 

if exc: 

raise exc 

 

connection_cache.set(self.hash, conn) 

 

last_bind = getattr(conn, '_last_bind', None) 

if not last_bind or \ 

last_bind[1][0] != bind_dn or \ 

last_bind[1][1] != bind_pwd: 

conn.simple_bind_s(bind_dn, bind_pwd) 

 

return conn 

 

def _getConnection(self): 

""" Private helper to get my connection out of the cache 

""" 

return connection_cache.get(self.hash) 

 

def _connect(self, connection_string, conn_timeout=5, op_timeout=-1): 

""" Factored out to allow usage by other pieces 

 

user_dn is assumed to have been encoded/escaped correctly 

""" 

connection = self.c_factory(connection_string) 

 

# Deny auto-chasing of referrals to be safe, we handle them instead 

try: 

connection.set_option(ldap.OPT_REFERRALS, ldap.DEREF_NEVER) 

except ldap.LDAPError: # Cannot set referrals, so do nothing 

pass 

 

# Set the connection timeout 

if conn_timeout > 0: 

connection.set_option(ldap.OPT_NETWORK_TIMEOUT, conn_timeout) 

 

# Set the operations timeout 

if op_timeout > 0: 

connection.timeout = op_timeout 

 

return connection 

 

def disconnect(self): 

""" Unbind the connection and invalidate the cache 

""" 

conn = self._getConnection() 

if conn is not None: 

connection_cache.invalidate(self.hash) 

conn.unbind_s() 

 

def search(self, base, scope=ldap.SCOPE_SUBTREE, fltr='(objectClass=*)', 

attrs=None, convert_filter=True, bind_dn=None, bind_pwd=None, 

raw=False): 

""" Search for entries in the database 

""" 

result = {'size': 0, 'results': [], 'exception': ''} 

if convert_filter: 

fltr = self._encode_incoming(fltr) 

base = escape_dn(self._encode_incoming(base), 

self.ldap_encoding) 

connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd) 

 

try: 

res = connection.search_s(base, scope, fltr, attrs) 

except ldap.PARTIAL_RESULTS: 

res_type, res = connection.result(all=0) 

except ldap.REFERRAL as e: 

connection = self._handle_referral(e) 

 

try: 

res = connection.search_s(base, scope, fltr, attrs) 

except ldap.PARTIAL_RESULTS: 

res_type, res = connection.result(all=0) 

 

for rec_dn, rec_dict in res: 

# When used against Active Directory, "rec_dict" may not be 

# be a dictionary in some cases (instead, it can be a list) 

# An example of a useless "res" entry that can be ignored 

# from AD is 

# (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,\ 

# DC=PORTAL,DC=LOCAL']) 

# This appears to be some sort of internal referral, but 

# we can't handle it, so we need to skip over it. 

try: 

items = list(rec_dict.items()) 

except AttributeError: 

# 'items' not found on rec_dict 

continue 

 

if raw: 

rec_dict['dn'] = rec_dn 

else: 

for key, value in items: 

if key == b'dn': 

del rec_dict[key] 

elif key.lower() not in BINARY_ATTRIBUTES: 

if isinstance(value, (list, tuple)): 

for i in range(len(value)): 

value[i] = self._encode_outgoing(value[i]) 

else: 

rec_dict[key] = self._encode_outgoing(value) 

 

rec_dict['dn'] = self._encode_outgoing(rec_dn) 

 

result['results'].append(rec_dict) 

result['size'] += 1 

 

return result 

 

def insert(self, base, rdn, attrs=None, bind_dn=None, bind_pwd=None): 

""" Insert a new record 

 

attrs is expected to be a mapping where the value may be a string 

or a sequence of strings. 

Multiple values may be expressed as a single string if the values 

are semicolon-delimited. 

Values can be marked as binary values, meaning they are not encoded 

as UTF-8, by appending ';binary' to the key. 

""" 

self._complainIfReadOnly() 

base = escape_dn(self._encode_incoming(base), self.ldap_encoding) 

rdn = escape_dn(self._encode_incoming(rdn), self.ldap_encoding) 

 

dn = rdn + b',' + base 

attribute_list = [] 

attrs = attrs and attrs or {} 

 

for attr_key, values in attrs.items(): 

if attr_key.endswith(';binary'): 

is_binary = True 

attr_key = attr_key[:-7] 

else: 

is_binary = False 

 

if not isinstance(attr_key, six.binary_type): 

attr_key = self._encode_incoming(attr_key) 

 

if isinstance(values, six.string_types) and not is_binary: 

values = [x.strip() for x in values.split(';')] 

elif isinstance(values, six.binary_type) and not is_binary: 

values = [x.strip() for x in values.split(b';')] 

 

if values != ['']: 

if not is_binary: 

values = [self._encode_incoming(x) for x in values] 

attribute_list.append((attr_key, values)) 

 

try: 

connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd) 

connection.add_s(dn, attribute_list) 

except ldap.REFERRAL as e: 

connection = self._handle_referral(e) 

connection.add_s(dn, attribute_list) 

 

def delete(self, dn, bind_dn=None, bind_pwd=None): 

""" Delete a record 

""" 

self._complainIfReadOnly() 

 

dn = escape_dn(self._encode_incoming(dn), self.ldap_encoding) 

 

try: 

connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd) 

connection.delete_s(dn) 

except ldap.REFERRAL as e: 

connection = self._handle_referral(e) 

connection.delete_s(dn) 

 

def modify(self, dn, mod_type=None, attrs=None, bind_dn=None, 

bind_pwd=None): 

""" Modify a record 

""" 

self._complainIfReadOnly() 

 

unescaped_dn = self._encode_incoming(dn) 

dn = escape_dn(unescaped_dn, self.ldap_encoding) 

res = self.search(base=unescaped_dn, scope=ldap.SCOPE_BASE, 

bind_dn=bind_dn, bind_pwd=bind_pwd, raw=True) 

attrs = attrs and attrs or {} 

cur_rec = res['results'][0] 

mod_list = [] 

 

for key, values in list(attrs.items()): 

 

if key.endswith(';binary'): 

key = key[:-7] 

is_binary = True 

else: 

is_binary = False 

 

if not isinstance(key, six.binary_type): 

key = self._encode_incoming(key) 

 

if not is_binary: 

if isinstance(values, six.string_types): 

values = [self._encode_incoming(x) for x in 

values.split(';')] 

else: 

values = [self._encode_incoming(x) for x in values] 

 

if isinstance(key, six.text_type): 

key = self._encode_incoming(key) 

 

if mod_type is None: 

if key not in cur_rec and values != [b'']: 

mod_list.append((ldap.MOD_ADD, key, values)) 

elif cur_rec.get(key, [b'']) != values and \ 

values not in ([b''], []): 

mod_list.append((ldap.MOD_REPLACE, key, values)) 

elif key in cur_rec and values in ([b''], []): 

mod_list.append((ldap.MOD_DELETE, key, None)) 

elif mod_type in (ldap.MOD_ADD, ldap.MOD_DELETE) and \ 

values == [b'']: 

continue 

elif mod_type == ldap.MOD_DELETE and \ 

set(values).difference(set(cur_rec.get(key, []))): 

continue 

else: 

mod_list.append((mod_type, key, values)) 

 

attrs[key] = values 

 

try: 

connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd) 

 

dn_parts = str2dn(dn) 

clean_dn_parts = [] 

for dn_part in dn_parts: 

for (attr_name, attr_val, flag) in dn_part: 

if isinstance(attr_name, six.text_type): 

attr_name = self._encode_incoming(attr_name) 

if isinstance(attr_val, six.text_type): 

attr_val = self._encode_incoming(attr_val) 

clean_dn_parts.append([(attr_name, attr_val, flag)]) 

 

rdn_attr = clean_dn_parts[0][0][0] 

raw_rdn = attrs.get(rdn_attr, '') 

if isinstance(raw_rdn, six.string_types): 

raw_rdn = [raw_rdn] 

new_rdn = raw_rdn[0] 

 

if new_rdn: 

rdn_value = self._encode_incoming(new_rdn) 

if rdn_value != cur_rec.get(rdn_attr)[0]: 

clean_dn_parts[0] = [(rdn_attr, rdn_value, 1)] 

dn_parts[0] = [(rdn_attr, raw_rdn[0], 1)] 

raw_utf8_rdn = rdn_attr + b'=' + rdn_value 

new_rdn = escape_dn(raw_utf8_rdn, self.ldap_encoding) 

connection.modrdn_s(dn, new_rdn) 

dn = dn2str(clean_dn_parts) 

 

if mod_list: 

connection.modify_s(dn, mod_list) 

else: 

debug_msg = 'Nothing to modify: %s' % dn 

self.logger().debug(debug_msg) 

 

except ldap.REFERRAL as e: 

connection = self._handle_referral(e) 

connection.modify_s(dn, mod_list) 

 

def _handle_referral(self, exception): 

""" Handle a referral specified in the passed-in exception 

""" 

payload = exception.args[0] 

info = payload.get('info') 

ldap_url = info[info.find('ldap'):] 

 

if ldapurl.isLDAPUrl(ldap_url): 

conn_str = ldapurl.LDAPUrl(ldap_url).initializeUrl() 

conn = self._connect(conn_str) 

conn.simple_bind_s(self._encode_incoming(self.bind_dn), 

self._encode_incoming(self.bind_pwd)) 

return conn 

else: 

raise ldap.CONNECT_ERROR('Bad referral "%s"' % str(exception)) 

 

def _complainIfReadOnly(self): 

""" Raise RuntimeError if the connection is set to `read-only` 

 

This method should be called before any directory tree modfication 

""" 

if self.read_only: 

raise RuntimeError( 

'Running in read-only mode, directory modifications disabled') 

 

def _encode_incoming(self, value): 

""" Encode a string value to the encoding set for the LDAP server 

 

- if "value" is unicode, it will be encoded to self.ldap_encoding, but 

only if self.ldap_encoding is set. 

- if "value" is not unicode, it is assumed to be encoded as 

self.api_encoding. It is decoded and encoded to self.ldap_encoding 

if self.ldap_encoding is set, unless self.api_encoding and 

self.ldap_encoding are identical. In that case the passed-in value 

is handed back unchanged. 

""" 

if value is None: 

return None 

 

elif isinstance(value, six.text_type): 

if self.ldap_encoding: 

value = value.encode(self.ldap_encoding) 

 

else: 

if self.api_encoding != self.ldap_encoding: 

if self.api_encoding: 

value = value.decode(self.api_encoding) 

 

if self.ldap_encoding: 

value = value.encode(self.ldap_encoding) 

 

return value 

 

def _encode_outgoing(self, value): 

""" Encode a string value to the API encoding 

 

- if "value" is unicode, it will be encoded to self.api_encoding, but 

only if self.api_encoding is set. 

- if "value" is not unicode, it is assumed to be encoded as 

self.ldap_encoding. It is decoded and encoded to self.api_encoding 

if self.api_encoding is set, unless self.ldap_encoding and 

self.api_encoding are identical. In that case the passed-in value 

is handed back unchanged. 

""" 

if value is None: 

return None 

 

elif isinstance(value, six.text_type): 

if self.api_encoding: 

value = value.encode(self.api_encoding) 

 

else: 

if self.api_encoding != self.ldap_encoding: 

if self.ldap_encoding: 

value = value.decode(self.ldap_encoding) 

 

if self.api_encoding: 

value = value.encode(self.api_encoding) 

 

return value