1# Copyright (C) 2005 Martin v. Löwis
2# Licensed to PSF under a Contributor Agreement.
3from _msi import *
4import fnmatch
5import os
6import re
7import string
8import sys
9import warnings
10
# Flag this module as deprecated; remove=(3, 13) names the Python version
# in which it is scheduled to go away.
warnings._deprecated(__name__, remove=(3, 13))

# True when sys.version mentions "AMD64" (presumably a 64-bit Windows build).
# It selects the "x64" summary template in init_database and the extra
# component attribute bit in Directory.start_component.
AMD64 = "AMD64" in sys.version
# Keep msilib.Win64 around to preserve backwards compatibility.
Win64 = AMD64
16
# Bit fields of an MSI column type code, as decoded by Table.sql().
# Partially taken from Wine.
datasizemask = 0x00ff      # low byte: declared column size
type_valid = 0x0100
type_localizable = 0x0200

typemask = 0x0c00          # selects one of the four storage kinds below
type_long = 0x0000
type_short = 0x0400
type_string = 0x0c00
type_binary = 0x0800

type_nullable = 0x1000
type_key = 0x2000

# Every bit Table.sql() knows how to interpret; anything else is reported.
# XXX temporary, localizable?
knownbits = (datasizemask | type_valid | type_localizable
             | typemask | type_nullable | type_key)
33
class Table:
    """Schema of one MSI table: collects columns and emits CREATE TABLE."""

    def __init__(self, name):
        self.name = name
        # (1-based column index, column name, type bits) triples.
        self.fields = []

    def add_field(self, index, name, type):
        """Register column *name* at 1-based position *index* with type bits *type*."""
        self.fields.append((index, name, type))

    def sql(self):
        """Return the CREATE TABLE statement for the registered columns.

        Side effect: self.fields is sorted in place. Type bits outside
        knownbits are reported with print().
        """
        self.fields.sort()
        columns = [None] * len(self.fields)
        primary_keys = []
        for position, colname, bits in self.fields:
            unknown = bits & ~knownbits
            if unknown:
                print("%s.%s unknown bits %x" % (self.name, colname, unknown))
            size = bits & datasizemask
            storage = bits & typemask
            if storage == type_string:
                tname = "CHAR(%d)" % size if size else "CHAR"
            elif storage == type_short:
                assert size == 2
                tname = "SHORT"
            elif storage == type_long:
                assert size == 4
                tname = "LONG"
            elif storage == type_binary:
                assert size == 0
                tname = "OBJECT"
            else:
                tname = "unknown"
                print("%s.%sunknown integer type %d" % (self.name, colname, size))
            flags = "" if bits & type_nullable else " NOT NULL"
            if bits & type_localizable:
                flags += " LOCALIZABLE"
            # Columns are placed by their declared 1-based index.
            columns[position - 1] = "`%s` %s%s" % (colname, tname, flags)
            if bits & type_key:
                primary_keys.append("`%s`" % colname)
        return "CREATE TABLE %s (%s PRIMARY KEY %s)" % (
            self.name, ", ".join(columns), ", ".join(primary_keys))

    def create(self, db):
        """Run this table's CREATE TABLE statement against *db*."""
        view = db.OpenView(self.sql())
        view.Execute(None)
        view.Close()
88
class _Unspecified:
    pass


def change_sequence(seq, action, seqno=_Unspecified, cond=_Unspecified):
    """Change the sequence number of an action in a sequence list.

    *seq* holds (action, condition, sequence number) tuples and is modified
    in place. Only the first entry for *action* is replaced. Arguments left
    unspecified keep the entry's current value.

    Raises ValueError if *action* does not occur in *seq*.
    """
    for pos, entry in enumerate(seq):
        if entry[0] != action:
            continue
        new_cond = entry[1] if cond is _Unspecified else cond
        new_seqno = entry[2] if seqno is _Unspecified else seqno
        seq[pos] = (action, new_cond, new_seqno)
        return
    raise ValueError("Action not found in sequence")
101
def add_data(db, table, values):
    """Insert rows into an existing table of an MSI database.

    db: an open database object (e.g. the one init_database returns).
    table: name of the target table.
    values: iterable of row tuples. Each must have exactly as many items as
        the table has columns. An item may be an int, a str, None (the field
        is left empty), or a Binary (its .name is passed to SetStream,
        presumably a file path).

    Raises TypeError for an unsupported item type, and MSIError (chained to
    the underlying error) when the database rejects a row. The view is
    closed even when a row fails.
    """
    v = db.OpenView("SELECT * FROM `%s`" % table)
    try:
        count = v.GetColumnInfo(MSICOLINFO_NAMES).GetFieldCount()
        r = CreateRecord(count)
        for value in values:
            assert len(value) == count, value
            for i in range(count):
                field = value[i]
                # Record fields are 1-based.
                if isinstance(field, int):
                    r.SetInteger(i + 1, field)
                elif isinstance(field, str):
                    r.SetString(i + 1, field)
                elif field is None:
                    pass
                elif isinstance(field, Binary):
                    r.SetStream(i + 1, field.name)
                else:
                    raise TypeError("Unsupported type %s" % field.__class__.__name__)
            try:
                v.Modify(MSIMODIFY_INSERT, r)
            except Exception as exc:
                # Report the row that failed, not the whole batch.
                raise MSIError("Could not insert " + repr(value) + " into " + table) from exc
            # Reset the record so None fields of the next row stay empty.
            r.ClearData()
    finally:
        v.Close()
127
128
def add_stream(db, name, path):
    """Store the file at *path* in the database's _Streams table as stream *name*.

    Both values are bound as record parameters rather than spliced into the
    SQL text, so a quote in *name* cannot break the statement. The view is
    closed even if execution fails.
    """
    v = db.OpenView("INSERT INTO _Streams (Name, Data) VALUES (?, ?)")
    try:
        r = CreateRecord(2)
        r.SetString(1, name)
        r.SetStream(2, path)
        v.Execute(r)
    finally:
        v.Close()
135
def init_database(name, schema,
                  ProductName, ProductCode, ProductVersion,
                  Manufacturer):
    """Create a new MSI database file and seed it with product metadata.

    - Any existing file at *name* is removed first; removal errors are ignored.
    - Every table in schema.tables is created.
    - _Validation is filled from schema._Validation_records.
    - The summary information is written and persisted.
    - ProductName, ProductCode (upper-cased), ProductVersion, Manufacturer and
      ProductLanguage are added to the Property table.

    The database is committed and returned still open.
    """
    try:
        os.unlink(name)
    except OSError:
        pass  # best effort: a missing file is fine
    code = ProductCode.upper()

    db = OpenDatabase(name, MSIDBOPEN_CREATE)
    for table in schema.tables:
        table.create(db)
    add_data(db, "_Validation", schema._Validation_records)

    # Summary information, with room for at most 20 properties.
    summary = db.GetSummaryInformation(20)
    put = summary.SetProperty
    put(PID_TITLE, "Installation Database")
    put(PID_SUBJECT, ProductName)
    put(PID_AUTHOR, Manufacturer)
    put(PID_TEMPLATE, "x64;1033" if AMD64 else "Intel;1033")
    put(PID_REVNUMBER, gen_uuid())
    put(PID_WORDCOUNT, 2)  # long file names, compressed, original media
    put(PID_PAGECOUNT, 200)
    put(PID_APPNAME, "Python MSI Library")
    # XXX more properties
    summary.Persist()

    properties = [
        ("ProductName", ProductName),
        ("ProductCode", code),
        ("ProductVersion", ProductVersion),
        ("Manufacturer", Manufacturer),
        ("ProductLanguage", "1033"),
    ]
    add_data(db, "Property", properties)
    db.Commit()
    return db
174
def add_tables(db, module):
    """Insert rows for every table named in module.tables.

    Each table's rows are read from the module attribute of the same name.
    """
    for table_name in module.tables:
        rows = getattr(module, table_name)
        add_data(db, table_name, rows)
178
def make_id(str):
    """Turn an arbitrary string into a valid MSI identifier.

    Every character other than an ASCII letter, digit, '.' or '_' becomes
    '_'. The result must start with a letter or underscore (see the regex
    below), so a leading digit or '.' gets a '_' prefix. An empty input
    yields "_" instead of raising IndexError.
    """
    # The parameter keeps its historical name ``str`` so keyword callers
    # still work; the builtin is not needed here.
    allowed = string.ascii_letters + string.digits + "._"
    ident = "".join(c if c in allowed else "_" for c in str)
    if not ident or ident[0] in string.digits + ".":
        ident = "_" + ident
    assert re.match("^[A-Za-z_][A-Za-z0-9_.]*$", ident), "FILE" + ident
    return ident
186
def gen_uuid():
    """Return a fresh UUID from UuidCreate(), upper-cased and wrapped in braces."""
    return "".join(("{", UuidCreate().upper(), "}"))
189
class CAB:
    """Collects files for one cabinet and embeds the built cabinet in a database."""

    def __init__(self, name):
        self.name = name
        self.files = []         # (source path, logical id) pairs in append order
        self.filenames = set()  # logical ids handed out by gen_id
        self.index = 0          # files appended so far; last sequence number

    def gen_id(self, file):
        """Return a logical id for *file* that is unique within this cabinet.

        The id comes from make_id(file); on collision ".1", ".2", ... is
        appended until an unused id is found.
        """
        logical = _logical = make_id(file)
        pos = 1
        while logical in self.filenames:
            logical = "%s.%d" % (_logical, pos)
            pos += 1
        self.filenames.add(logical)
        return logical

    def append(self, full, file, logical):
        """Queue the file at path *full* for the cabinet.

        Directories are ignored (returns None). When *logical* is empty, an
        id is generated from *file*. Returns (sequence number, logical id).
        """
        if os.path.isdir(full):
            return
        if not logical:
            logical = self.gen_id(file)
        self.index += 1
        self.files.append((full, logical))
        return self.index, logical

    def commit(self, db):
        """Build the cabinet, register and embed it in *db*, then commit.

        The cabinet is written with FCICreate and recorded as Media row 1
        (last sequence = self.index). It is stored as a stream named
        self.name; the "#" prefix presumably marks it as embedded.
        """
        import shutil
        from tempfile import mkdtemp
        # tempfile.mktemp() only returns a name, which another process could
        # claim before FCICreate writes it. A private directory from
        # mkdtemp() avoids that race, and the whole directory is removed
        # afterwards even if building or embedding fails.
        tmpdir = mkdtemp()
        filename = os.path.join(tmpdir, "cabinet.cab")
        try:
            FCICreate(filename, self.files)
            add_data(db, "Media",
                     [(1, self.index, None, "#" + self.name, None, None)])
            add_stream(db, self.name, filename)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
        db.Commit()
224
# Logical directory names already used by any Directory, so each new one is unique.
_directories = set()


class Directory:
    def __init__(self, db, cab, basedir, physical, _logical, default, componentflags=None):
        """Create a new directory in the Directory table.

        At each point in time the directory has a current component. It is
        either created explicitly through start_component, or implicitly
        when files are added for the first time. Files are added both to
        the current component and to the cab file.

        basedir: parent Directory object, or None.
        physical: path of the directory, relative to basedir.
        _logical: desired logical name. It is made a valid identifier and,
            if already taken, suffixed with a number.
        default: value for the DefaultDir column.
        componentflags: default flags for components started here.
        """
        index = 1
        _logical = make_id(_logical)
        logical = _logical
        while logical in _directories:
            logical = "%s%d" % (_logical, index)
            index += 1
        _directories.add(logical)
        self.db = db
        self.cab = cab
        self.basedir = basedir
        self.physical = physical
        self.logical = logical
        self.component = None
        self.short_names = set()
        self.ids = set()
        self.keyfiles = {}
        self.componentflags = componentflags
        if basedir:
            self.absolute = os.path.join(basedir.absolute, physical)
            blogical = basedir.logical
        else:
            self.absolute = physical
            blogical = None
        add_data(db, "Directory", [(logical, blogical, default)])

    def start_component(self, component = None, feature = None, flags = None, keyfile = None, uuid=None):
        """Add a Component row and make it this directory's current component.

        - No component name: the directory's logical name is used.
        - No feature: the module's current feature is used.
        - No flags: the directory's default componentflags are used.
        - No keyfile: KeyPath is left null.
        - No uuid: a fresh one is generated; a given uuid is upper-cased.
        """
        if flags is None:
            flags = self.componentflags
        if uuid is None:
            uuid = gen_uuid()
        else:
            uuid = uuid.upper()
        if component is None:
            component = self.logical
        self.component = component
        if AMD64:
            # 256 is presumably msidbComponentAttributes64bit. componentflags
            # defaults to None, and None | 256 would raise TypeError, so
            # treat "no flags" as 0 here.
            flags = (flags or 0) | 256
        if keyfile:
            keyid = self.cab.gen_id(keyfile)
            self.keyfiles[keyfile] = keyid
        else:
            keyid = None
        add_data(self.db, "Component",
                        [(component, uuid, self.logical, flags, None, keyid)])
        if feature is None:
            feature = current_feature
        add_data(self.db, "FeatureComponents",
                        [(feature.id, component)])

    def make_short(self, file):
        """Return a short (8.3-style) name for *file*, unique within this directory.

        If the long name already fits and is unchanged by sanitising, its
        upper-cased form is used. Otherwise a PREFIX~N.SUF name is generated.
        """
        oldfile = file
        file = file.replace('+', '_')
        file = ''.join(c for c in file if c not in r' "/\[]:;=,')
        parts = file.split(".")
        if len(parts) > 1:
            prefix = "".join(parts[:-1]).upper()
            suffix = parts[-1].upper()
            if not prefix:
                prefix = suffix
                suffix = None
        else:
            prefix = file.upper()
            suffix = None
        if len(parts) < 3 and len(prefix) <= 8 and file == oldfile and (
                                                not suffix or len(suffix) <= 3):
            if suffix:
                file = prefix+"."+suffix
            else:
                file = prefix
        else:
            file = None
        if file is None or file in self.short_names:
            prefix = prefix[:6]
            if suffix:
                suffix = suffix[:3]
            pos = 1
            while True:
                if suffix:
                    file = "%s~%d.%s" % (prefix, pos, suffix)
                else:
                    file = "%s~%d" % (prefix, pos)
                if file not in self.short_names:
                    break
                pos += 1
                assert pos < 10000
                # Keep the name within 8 characters as the counter gains digits.
                if pos in (10, 100, 1000):
                    prefix = prefix[:-1]
        self.short_names.add(file)
        assert not re.search(r'[\?|><:/*"+,;=\[\]]', file) # restrictions on short names
        return file

    def add_file(self, file, src=None, version=None, language=None):
        """Add a file to the current component, starting one if needed.

        By default the file name in the source and in the File table are
        identical. If *src* is given, it is interpreted relative to this
        directory's absolute path. Optionally a version and a language can
        be given for the File table entry. Returns the file's logical id.
        """
        if not self.component:
            self.start_component(self.logical, current_feature, 0)
        if not src:
            # Allow relative paths for file if src is not specified
            src = file
            file = os.path.basename(file)
        absolute = os.path.join(self.absolute, src)
        # Characters not allowed in long file names. The quote belongs inside
        # the character class; the old pattern only rejected a forbidden
        # character immediately followed by '"'. Backslash is also rejected.
        assert not re.search(r'[\\?|><:/*"]', file) # restrictions on long names
        logical = self.keyfiles.get(file)
        sequence, logical = self.cab.append(absolute, file, logical)
        assert logical not in self.ids
        self.ids.add(logical)
        short = self.make_short(file)
        full = "%s|%s" % (short, file)
        filesize = os.stat(absolute).st_size
        # constants.msidbFileAttributesVital
        # Compressed omitted, since it is the database default
        # could add r/o, system, hidden
        attributes = 512
        add_data(self.db, "File",
                        [(logical, self.component, full, filesize, version,
                         language, attributes, sequence)])
        # NOTE: MsiFileHash rows for unversioned files are intentionally not added.
        # Automatically remove .pyc/.pyo files on uninstall (install mode 2).
        # XXX: adding so many RemoveFile entries makes installer unbelievably
        # slow. So instead, we have to use wildcard remove entries
        if file.endswith(".py"):
            add_data(self.db, "RemoveFile",
                      [(logical+"c", self.component, "%sC|%sc" % (short, file),
                        self.logical, 2),
                       (logical+"o", self.component, "%sO|%so" % (short, file),
                        self.logical, 2)])
        return logical

    def glob(self, pattern, exclude = None):
        """Add the files in this directory matching a glob pattern.

        Files listed in *exclude* are skipped. Subdirectories that match
        are skipped too, because CAB.append refuses directories. Dot-files
        only match when the pattern itself starts with '.'. Returns the
        list of matched names, including excluded ones; an unreadable
        directory yields [].
        """
        try:
            files = os.listdir(self.absolute)
        except OSError:
            return []
        if pattern[:1] != '.':
            files = (f for f in files if f[0] != '.')
        files = fnmatch.filter(files, pattern)
        for f in files:
            if exclude and f in exclude:
                continue
            # add_file would fail unpacking CAB.append's None for a directory.
            if os.path.isdir(os.path.join(self.absolute, f)):
                continue
            self.add_file(f)
        return files

    def remove_pyc(self):
        "Remove .pyc files on uninstall"
        add_data(self.db, "RemoveFile",
                 [(self.component+"c", self.component, "*.pyc", self.logical, 2)])
399
class Binary:
    """Wraps a file name so add_data stores it as a stream (via SetStream)."""

    def __init__(self, fname):
        self.name = fname

    def __repr__(self):
        template = 'msilib.Binary(os.path.join(dirname,"%s"))'
        return template % self.name
405
class Feature:
    """A row in the Feature table."""

    def __init__(self, db, id, title, desc, display, level = 1,
                 parent=None, directory = None, attributes=0):
        self.id = id
        # A truthy parent is a Feature and is referenced by its id; any falsy
        # value is passed through unchanged.
        parent_ref = parent.id if parent else parent
        row = (id, parent_ref, title, desc, display,
               level, directory, attributes)
        add_data(db, "Feature", [row])

    def set_current(self):
        """Make this the module-wide current_feature.

        Directory.start_component uses it when no feature is given, and
        Directory.add_file uses it when it starts a component implicitly.
        """
        global current_feature
        current_feature = self
418
class Control:
    """One control of a Dialog; its methods add rows keyed by (dialog, control)."""

    def __init__(self, dlg, name):
        self.dlg = dlg
        self.name = name

    def _insert(self, table, *fields):
        """Add one row to *table*, prefixed with the dialog and control names."""
        add_data(self.dlg.db, table, [(self.dlg.name, self.name) + fields])

    def event(self, event, argument, condition = "1", ordering = None):
        """Add a ControlEvent row."""
        self._insert("ControlEvent", event, argument, condition, ordering)

    def mapping(self, event, attribute):
        """Add an EventMapping row."""
        self._insert("EventMapping", event, attribute)

    def condition(self, action, condition):
        """Add a ControlCondition row."""
        self._insert("ControlCondition", action, condition)
436
class RadioButtonGroup(Control):
    """A RadioButtonGroup control bound to *property*; add() appends buttons."""

    def __init__(self, dlg, name, property):
        super().__init__(dlg, name)
        self.property = property
        # 1-based counter written as the second column of each RadioButton row.
        self.index = 1

    def add(self, name, x, y, w, h, text, value = None):
        """Add a radio button; *value* defaults to the button's *name*."""
        chosen = name if value is None else value
        row = (self.property, self.index, chosen, x, y, w, h, text, None)
        add_data(self.dlg.db, "RadioButton", [row])
        self.index += 1
451
class Dialog:
    """A row in the Dialog table, with helpers that add its Control rows."""

    def __init__(self, db, name, x, y, w, h, attr, title, first, default, cancel):
        self.db = db
        self.name = name
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        row = (name, x, y, w, h, attr, title, first, default, cancel)
        add_data(db, "Dialog", [row])

    def control(self, name, type, x, y, w, h, attr, prop, text, next, help):
        """Add a generic Control row and return a Control handle for it."""
        row = (self.name, name, type, x, y, w, h, attr, prop, text, next, help)
        add_data(self.db, "Control", [row])
        return Control(self, name)

    def text(self, name, x, y, w, h, attr, text):
        """Add a static Text control."""
        return self.control(name, "Text", x, y, w, h, attr,
                            None, text, None, None)

    def bitmap(self, name, x, y, w, h, text):
        """Add a Bitmap control (attributes fixed to 1)."""
        return self.control(name, "Bitmap", x, y, w, h, 1,
                            None, text, None, None)

    def line(self, name, x, y, w, h):
        """Add a Line control (attributes fixed to 1)."""
        return self.control(name, "Line", x, y, w, h, 1,
                            None, None, None, None)

    def pushbutton(self, name, x, y, w, h, attr, text, next):
        """Add a PushButton control."""
        return self.control(name, "PushButton", x, y, w, h, attr,
                            None, text, next, None)

    def radiogroup(self, name, x, y, w, h, attr, prop, text, next):
        """Add a RadioButtonGroup control and return a handle for its buttons."""
        row = (self.name, name, "RadioButtonGroup",
               x, y, w, h, attr, prop, text, next, None)
        add_data(self.db, "Control", [row])
        return RadioButtonGroup(self, name, prop)

    def checkbox(self, name, x, y, w, h, attr, prop, text, next):
        """Add a CheckBox control bound to *prop*."""
        return self.control(name, "CheckBox", x, y, w, h, attr,
                            prop, text, next, None)
485