1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for flags.FlagValues class."""
16
17import collections
18import copy
19import pickle
20import types
21from unittest import mock
22
23from absl import logging
24from absl.flags import _defines
25from absl.flags import _exceptions
26from absl.flags import _flagvalues
27from absl.flags import _helpers
28from absl.flags import _validators
29from absl.flags.tests import module_foo
30from absl.testing import absltest
31from absl.testing import parameterized
32
33
34class FlagValuesTest(absltest.TestCase):
35
36  def test_bool_flags(self):
37    for arg, expected in (('--nothing', True),
38                          ('--nothing=true', True),
39                          ('--nothing=false', False),
40                          ('--nonothing', False)):
41      fv = _flagvalues.FlagValues()
42      _defines.DEFINE_boolean('nothing', None, '', flag_values=fv)
43      fv(('./program', arg))
44      self.assertIs(expected, fv.nothing)
45
46    for arg in ('--nonothing=true', '--nonothing=false'):
47      fv = _flagvalues.FlagValues()
48      _defines.DEFINE_boolean('nothing', None, '', flag_values=fv)
49      with self.assertRaises(ValueError):
50        fv(('./program', arg))
51
52  def test_boolean_flag_parser_gets_string_argument(self):
53    for arg, expected in (('--nothing', 'true'),
54                          ('--nothing=true', 'true'),
55                          ('--nothing=false', 'false'),
56                          ('--nonothing', 'false')):
57      fv = _flagvalues.FlagValues()
58      _defines.DEFINE_boolean('nothing', None, '', flag_values=fv)
59      with mock.patch.object(fv['nothing'].parser, 'parse') as mock_parse:
60        fv(('./program', arg))
61        mock_parse.assert_called_once_with(expected)
62
63  def test_unregistered_flags_are_cleaned_up(self):
64    fv = _flagvalues.FlagValues()
65    module, module_name = _helpers.get_calling_module_object_and_name()
66
67    # Define first flag.
68    _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c')
69    old_cores_flag = fv['cores']
70    fv.register_key_flag_for_module(module_name, old_cores_flag)
71    self.assertEqual(fv.flags_by_module_dict(),
72                     {module_name: [old_cores_flag]})
73    self.assertEqual(fv.flags_by_module_id_dict(),
74                     {id(module): [old_cores_flag]})
75    self.assertEqual(fv.key_flags_by_module_dict(),
76                     {module_name: [old_cores_flag]})
77
78    # Redefine the same flag.
79    _defines.DEFINE_integer(
80        'cores', 4, '', flag_values=fv, short_name='c', allow_override=True)
81    new_cores_flag = fv['cores']
82    self.assertNotEqual(old_cores_flag, new_cores_flag)
83    self.assertEqual(fv.flags_by_module_dict(),
84                     {module_name: [new_cores_flag]})
85    self.assertEqual(fv.flags_by_module_id_dict(),
86                     {id(module): [new_cores_flag]})
87    # old_cores_flag is removed from key flags, and the new_cores_flag is
88    # not automatically added because it must be registered explicitly.
89    self.assertEqual(fv.key_flags_by_module_dict(), {module_name: []})
90
91    # Define a new flag but with the same short_name.
92    _defines.DEFINE_integer(
93        'changelist',
94        0,
95        '',
96        flag_values=fv,
97        short_name='c',
98        allow_override=True)
99    old_changelist_flag = fv['changelist']
100    fv.register_key_flag_for_module(module_name, old_changelist_flag)
101    # The short named flag -c is overridden to be the old_changelist_flag.
102    self.assertEqual(fv['c'], old_changelist_flag)
103    self.assertNotEqual(fv['c'], new_cores_flag)
104    self.assertEqual(fv.flags_by_module_dict(),
105                     {module_name: [new_cores_flag, old_changelist_flag]})
106    self.assertEqual(fv.flags_by_module_id_dict(),
107                     {id(module): [new_cores_flag, old_changelist_flag]})
108    self.assertEqual(fv.key_flags_by_module_dict(),
109                     {module_name: [old_changelist_flag]})
110
111    # Define a flag only with the same long name.
112    _defines.DEFINE_integer(
113        'changelist',
114        0,
115        '',
116        flag_values=fv,
117        short_name='l',
118        allow_override=True)
119    new_changelist_flag = fv['changelist']
120    self.assertNotEqual(old_changelist_flag, new_changelist_flag)
121    self.assertEqual(fv.flags_by_module_dict(),
122                     {module_name: [new_cores_flag,
123                                    old_changelist_flag,
124                                    new_changelist_flag]})
125    self.assertEqual(fv.flags_by_module_id_dict(),
126                     {id(module): [new_cores_flag,
127                                   old_changelist_flag,
128                                   new_changelist_flag]})
129    self.assertEqual(fv.key_flags_by_module_dict(),
130                     {module_name: [old_changelist_flag]})
131
132    # Delete the new changelist's long name, it should still be registered
133    # because of its short name.
134    del fv.changelist
135    self.assertNotIn('changelist', fv)
136    self.assertEqual(fv.flags_by_module_dict(),
137                     {module_name: [new_cores_flag,
138                                    old_changelist_flag,
139                                    new_changelist_flag]})
140    self.assertEqual(fv.flags_by_module_id_dict(),
141                     {id(module): [new_cores_flag,
142                                   old_changelist_flag,
143                                   new_changelist_flag]})
144    self.assertEqual(fv.key_flags_by_module_dict(),
145                     {module_name: [old_changelist_flag]})
146
147    # Delete the new changelist's short name, it should be removed.
148    del fv.l
149    self.assertNotIn('l', fv)
150    self.assertEqual(fv.flags_by_module_dict(),
151                     {module_name: [new_cores_flag,
152                                    old_changelist_flag]})
153    self.assertEqual(fv.flags_by_module_id_dict(),
154                     {id(module): [new_cores_flag,
155                                   old_changelist_flag]})
156    self.assertEqual(fv.key_flags_by_module_dict(),
157                     {module_name: [old_changelist_flag]})
158
159  def _test_find_module_or_id_defining_flag(self, test_id):
160    """Tests for find_module_defining_flag and find_module_id_defining_flag.
161
162    Args:
163      test_id: True to test find_module_id_defining_flag, False to test
164          find_module_defining_flag.
165    """
166    fv = _flagvalues.FlagValues()
167    current_module, current_module_name = (
168        _helpers.get_calling_module_object_and_name())
169    alt_module_name = _flagvalues.__name__
170
171    if test_id:
172      current_module_or_id = id(current_module)
173      alt_module_or_id = id(_flagvalues)
174      testing_fn = fv.find_module_id_defining_flag
175    else:
176      current_module_or_id = current_module_name
177      alt_module_or_id = alt_module_name
178      testing_fn = fv.find_module_defining_flag
179
180    # Define first flag.
181    _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c')
182    module_or_id_cores = testing_fn('cores')
183    self.assertEqual(module_or_id_cores, current_module_or_id)
184    module_or_id_c = testing_fn('c')
185    self.assertEqual(module_or_id_c, current_module_or_id)
186
187    # Redefine the same flag in another module.
188    _defines.DEFINE_integer(
189        'cores',
190        4,
191        '',
192        flag_values=fv,
193        module_name=alt_module_name,
194        short_name='c',
195        allow_override=True)
196    module_or_id_cores = testing_fn('cores')
197    self.assertEqual(module_or_id_cores, alt_module_or_id)
198    module_or_id_c = testing_fn('c')
199    self.assertEqual(module_or_id_c, alt_module_or_id)
200
201    # Define a new flag but with the same short_name.
202    _defines.DEFINE_integer(
203        'changelist',
204        0,
205        '',
206        flag_values=fv,
207        short_name='c',
208        allow_override=True)
209    module_or_id_cores = testing_fn('cores')
210    self.assertEqual(module_or_id_cores, alt_module_or_id)
211    module_or_id_changelist = testing_fn('changelist')
212    self.assertEqual(module_or_id_changelist, current_module_or_id)
213    module_or_id_c = testing_fn('c')
214    self.assertEqual(module_or_id_c, current_module_or_id)
215
216    # Define a flag in another module only with the same long name.
217    _defines.DEFINE_integer(
218        'changelist',
219        0,
220        '',
221        flag_values=fv,
222        module_name=alt_module_name,
223        short_name='l',
224        allow_override=True)
225    module_or_id_cores = testing_fn('cores')
226    self.assertEqual(module_or_id_cores, alt_module_or_id)
227    module_or_id_changelist = testing_fn('changelist')
228    self.assertEqual(module_or_id_changelist, alt_module_or_id)
229    module_or_id_c = testing_fn('c')
230    self.assertEqual(module_or_id_c, current_module_or_id)
231    module_or_id_l = testing_fn('l')
232    self.assertEqual(module_or_id_l, alt_module_or_id)
233
234    # Delete the changelist flag, its short name should still be registered.
235    del fv.changelist
236    module_or_id_changelist = testing_fn('changelist')
237    self.assertIsNone(module_or_id_changelist)
238    module_or_id_c = testing_fn('c')
239    self.assertEqual(module_or_id_c, current_module_or_id)
240    module_or_id_l = testing_fn('l')
241    self.assertEqual(module_or_id_l, alt_module_or_id)
242
  def test_find_module_defining_flag(self):
    """Runs the shared scenario against find_module_defining_flag."""
    self._test_find_module_or_id_defining_flag(test_id=False)
245
  def test_find_module_id_defining_flag(self):
    """Runs the shared scenario against find_module_id_defining_flag."""
    self._test_find_module_or_id_defining_flag(test_id=True)
248
249  def test_set_default(self):
250    fv = _flagvalues.FlagValues()
251    fv.mark_as_parsed()
252    with self.assertRaises(_exceptions.UnrecognizedFlagError):
253      fv.set_default('changelist', 1)
254    _defines.DEFINE_integer('changelist', 0, 'help', flag_values=fv)
255    self.assertEqual(0, fv.changelist)
256    fv.set_default('changelist', 2)
257    self.assertEqual(2, fv.changelist)
258
259  def test_default_gnu_getopt_value(self):
260    self.assertTrue(_flagvalues.FlagValues().is_gnu_getopt())
261
262  def test_known_only_flags_in_gnustyle(self):
263
264    def run_test(argv, defined_py_flags, expected_argv):
265      fv = _flagvalues.FlagValues()
266      fv.set_gnu_getopt(True)
267      for f in defined_py_flags:
268        if f.startswith('b'):
269          _defines.DEFINE_boolean(f, False, 'help', flag_values=fv)
270        else:
271          _defines.DEFINE_string(f, 'default', 'help', flag_values=fv)
272      output_argv = fv(argv, known_only=True)
273      self.assertEqual(expected_argv, output_argv)
274
275    run_test(
276        argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
277        defined_py_flags=[],
278        expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '))
279    run_test(
280        argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
281        defined_py_flags=['f1'],
282        expected_argv='0 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '))
283    run_test(
284        argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
285        defined_py_flags=['f2'],
286        expected_argv='0 --f1=v1 cmd --b1 --f3 v3 --nob2'.split(' '))
287    run_test(
288        argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
289        defined_py_flags=['b1'],
290        expected_argv='0 --f1=v1 cmd --f2 v2 --f3 v3 --nob2'.split(' '))
291    run_test(
292        argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
293        defined_py_flags=['f3'],
294        expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --nob2'.split(' '))
295    run_test(
296        argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
297        defined_py_flags=['b2'],
298        expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3'.split(' '))
299    run_test(
300        argv=('0 --f1=v1 cmd --undefok=f1 --f2 v2 --b1 '
301              '--f3 v3 --nob2').split(' '),
302        defined_py_flags=['b2'],
303        expected_argv='0 cmd --f2 v2 --b1 --f3 v3'.split(' '))
304    run_test(
305        argv=('0 --f1=v1 cmd --undefok f1,f2 --f2 v2 --b1 '
306              '--f3 v3 --nob2').split(' '),
307        defined_py_flags=['b2'],
308        # Note v2 is preserved here, since undefok requires the flag being
309        # specified in the form of --flag=value.
310        expected_argv='0 cmd v2 --b1 --f3 v3'.split(' '))
311
312  def test_invalid_flag_name(self):
313    with self.assertRaises(_exceptions.Error):
314      _defines.DEFINE_boolean('test ', 0, '')
315
316    with self.assertRaises(_exceptions.Error):
317      _defines.DEFINE_boolean(' test', 0, '')
318
319    with self.assertRaises(_exceptions.Error):
320      _defines.DEFINE_boolean('te st', 0, '')
321
322    with self.assertRaises(_exceptions.Error):
323      _defines.DEFINE_boolean('', 0, '')
324
325    with self.assertRaises(_exceptions.Error):
326      _defines.DEFINE_boolean(1, 0, '')  # type: ignore
327
328  def test_len(self):
329    fv = _flagvalues.FlagValues()
330    self.assertEmpty(fv)
331    self.assertFalse(fv)
332
333    _defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv)
334    self.assertLen(fv, 1)
335    self.assertTrue(fv)
336
337    _defines.DEFINE_boolean(
338        'bool', False, 'help', short_name='b', flag_values=fv)
339    self.assertLen(fv, 3)
340    self.assertTrue(fv)
341
342  def test_pickle(self):
343    fv = _flagvalues.FlagValues()
344    with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"):
345      pickle.dumps(fv)
346
347  def test_copy(self):
348    fv = _flagvalues.FlagValues()
349    _defines.DEFINE_integer('answer', 0, 'help', flag_values=fv)
350    fv(['', '--answer=1'])
351
352    with self.assertRaisesRegex(TypeError,
353                                'FlagValues does not support shallow copies'):
354      copy.copy(fv)
355
356    fv2 = copy.deepcopy(fv)
357    self.assertEqual(fv2.answer, 1)
358
359    fv2.answer = 42
360    self.assertEqual(fv2.answer, 42)
361    self.assertEqual(fv.answer, 1)
362
363  def test_conflicting_flags(self):
364    fv = _flagvalues.FlagValues()
365    with self.assertRaises(_exceptions.FlagNameConflictsWithMethodError):
366      _defines.DEFINE_boolean('is_gnu_getopt', False, 'help', flag_values=fv)
367    _defines.DEFINE_boolean(
368        'is_gnu_getopt',
369        False,
370        'help',
371        flag_values=fv,
372        allow_using_method_names=True)
373    self.assertFalse(fv['is_gnu_getopt'].value)
374    self.assertIsInstance(fv.is_gnu_getopt, types.MethodType)
375
376  def test_get_flags_for_module(self):
377    fv = _flagvalues.FlagValues()
378    _defines.DEFINE_string('foo', None, 'help', flag_values=fv)
379    module_foo.define_flags(fv)
380    flags = fv.get_flags_for_module('__main__')
381
382    self.assertEqual({'foo'}, {flag.name for flag in flags})
383
384    flags = fv.get_flags_for_module(module_foo)
385    self.assertEqual({'tmod_foo_bool', 'tmod_foo_int', 'tmod_foo_str'},
386                     {flag.name for flag in flags})
387
388  def test_get_help(self):
389    fv = _flagvalues.FlagValues()
390    self.assertMultiLineEqual('''\
391--flagfile: Insert flag definitions from the given file into the command line.
392  (default: '')
393--undefok: comma-separated list of flag names that it is okay to specify on the
394  command line even if the program does not define a flag with that name.
395  IMPORTANT: flags in this list that have arguments MUST use the --flag=value
396  format.
397  (default: '')''', fv.get_help())
398
399    module_foo.define_flags(fv)
400    self.assertMultiLineEqual('''
401absl.flags.tests.module_bar:
402  --tmod_bar_t: Sample int flag.
403    (default: '4')
404    (an integer)
405  --tmod_bar_u: Sample int flag.
406    (default: '5')
407    (an integer)
408  --tmod_bar_v: Sample int flag.
409    (default: '6')
410    (an integer)
411  --[no]tmod_bar_x: Boolean flag.
412    (default: 'true')
413  --tmod_bar_y: String flag.
414    (default: 'default')
415  --[no]tmod_bar_z: Another boolean flag from module bar.
416    (default: 'false')
417
418absl.flags.tests.module_foo:
419  --[no]tmod_foo_bool: Boolean flag from module foo.
420    (default: 'true')
421  --tmod_foo_int: Sample int flag.
422    (default: '3')
423    (an integer)
424  --tmod_foo_str: String flag.
425    (default: 'default')
426
427absl.flags:
428  --flagfile: Insert flag definitions from the given file into the command line.
429    (default: '')
430  --undefok: comma-separated list of flag names that it is okay to specify on
431    the command line even if the program does not define a flag with that name.
432    IMPORTANT: flags in this list that have arguments MUST use the --flag=value
433    format.
434    (default: '')''', fv.get_help())
435
436    self.assertMultiLineEqual('''
437xxxxabsl.flags.tests.module_bar:
438xxxx  --tmod_bar_t: Sample int flag.
439xxxx    (default: '4')
440xxxx    (an integer)
441xxxx  --tmod_bar_u: Sample int flag.
442xxxx    (default: '5')
443xxxx    (an integer)
444xxxx  --tmod_bar_v: Sample int flag.
445xxxx    (default: '6')
446xxxx    (an integer)
447xxxx  --[no]tmod_bar_x: Boolean flag.
448xxxx    (default: 'true')
449xxxx  --tmod_bar_y: String flag.
450xxxx    (default: 'default')
451xxxx  --[no]tmod_bar_z: Another boolean flag from module bar.
452xxxx    (default: 'false')
453
454xxxxabsl.flags.tests.module_foo:
455xxxx  --[no]tmod_foo_bool: Boolean flag from module foo.
456xxxx    (default: 'true')
457xxxx  --tmod_foo_int: Sample int flag.
458xxxx    (default: '3')
459xxxx    (an integer)
460xxxx  --tmod_foo_str: String flag.
461xxxx    (default: 'default')
462
463xxxxabsl.flags:
464xxxx  --flagfile: Insert flag definitions from the given file into the command
465xxxx    line.
466xxxx    (default: '')
467xxxx  --undefok: comma-separated list of flag names that it is okay to specify
468xxxx    on the command line even if the program does not define a flag with that
469xxxx    name.  IMPORTANT: flags in this list that have arguments MUST use the
470xxxx    --flag=value format.
471xxxx    (default: '')''', fv.get_help(prefix='xxxx'))
472
473    self.assertMultiLineEqual('''
474absl.flags.tests.module_bar:
475  --tmod_bar_t: Sample int flag.
476    (default: '4')
477    (an integer)
478  --tmod_bar_u: Sample int flag.
479    (default: '5')
480    (an integer)
481  --tmod_bar_v: Sample int flag.
482    (default: '6')
483    (an integer)
484  --[no]tmod_bar_x: Boolean flag.
485    (default: 'true')
486  --tmod_bar_y: String flag.
487    (default: 'default')
488  --[no]tmod_bar_z: Another boolean flag from module bar.
489    (default: 'false')
490
491absl.flags.tests.module_foo:
492  --[no]tmod_foo_bool: Boolean flag from module foo.
493    (default: 'true')
494  --tmod_foo_int: Sample int flag.
495    (default: '3')
496    (an integer)
497  --tmod_foo_str: String flag.
498    (default: 'default')''', fv.get_help(include_special_flags=False))
499
500  def test_str(self):
501    fv = _flagvalues.FlagValues()
502    self.assertEqual(str(fv), fv.get_help())
503    module_foo.define_flags(fv)
504    self.assertEqual(str(fv), fv.get_help())
505
506  def test_empty_argv(self):
507    fv = _flagvalues.FlagValues()
508    with self.assertRaises(ValueError):
509      fv([])
510
511  def test_invalid_argv(self):
512    fv = _flagvalues.FlagValues()
513    with self.assertRaises(TypeError):
514      fv('./program')  # type: ignore
515    with self.assertRaises(TypeError):
516      fv(b'./program')  # type: ignore
517
518  def test_flags_dir(self):
519    flag_values = _flagvalues.FlagValues()
520    flag_name1 = 'bool_flag'
521    flag_name2 = 'string_flag'
522    flag_name3 = 'float_flag'
523    description = 'Description'
524    _defines.DEFINE_boolean(
525        flag_name1, None, description, flag_values=flag_values)
526    _defines.DEFINE_string(
527        flag_name2, None, description, flag_values=flag_values)
528    self.assertEqual(sorted([flag_name1, flag_name2]), dir(flag_values))
529
530    _defines.DEFINE_float(
531        flag_name3, None, description, flag_values=flag_values)
532    self.assertEqual(
533        sorted([flag_name1, flag_name2, flag_name3]), dir(flag_values))
534
535  def test_flags_into_string_deterministic(self):
536    flag_values = _flagvalues.FlagValues()
537    _defines.DEFINE_string(
538        'fa', 'x', '', flag_values=flag_values, module_name='mb')
539    _defines.DEFINE_string(
540        'fb', 'x', '', flag_values=flag_values, module_name='mb')
541    _defines.DEFINE_string(
542        'fc', 'x', '', flag_values=flag_values, module_name='ma')
543    _defines.DEFINE_string(
544        'fd', 'x', '', flag_values=flag_values, module_name='ma')
545
546    expected = ('--fc=x\n'
547                '--fd=x\n'
548                '--fa=x\n'
549                '--fb=x\n')
550
551    flags_by_module_items = sorted(
552        flag_values.flags_by_module_dict().items(), reverse=True)
553    for _, module_flags in flags_by_module_items:
554      module_flags.sort(reverse=True)
555
556    flag_values.__dict__['__flags_by_module'] = collections.OrderedDict(
557        flags_by_module_items)
558
559    actual = flag_values.flags_into_string()
560    self.assertEqual(expected, actual)
561
562  def test_validate_all_flags(self):
563    fv = _flagvalues.FlagValues()
564    _defines.DEFINE_string('name', None, '', flag_values=fv)
565    _validators.mark_flag_as_required('name', flag_values=fv)
566    with self.assertRaises(_exceptions.IllegalFlagValueError):
567      fv.validate_all_flags()
568    fv.name = 'test'
569    fv.validate_all_flags()
570
571
class FlagValuesLoggingTest(absltest.TestCase):
  """Guards against infinite recursion between logging and flags.

  Logging can, and does, happen before flags are initialized. Any warning
  emitted by flagvalues during that time must not lead to unbounded
  recursion.
  """

  def test_logging_do_not_recurse(self):
    """Logs a plain message and an exception; both calls must return."""
    logging.info('test info')

    # logging.exception has to run while an exception is being handled.
    try:
      raise ValueError('test exception')
    except ValueError:
      logging.exception('test message')
586
587
class FlagSubstrMatchingTests(parameterized.TestCase):
  """Checks that abbreviated flag names are not matched to defined flags."""

  # Complete argv tuples that must always fail to parse: '--boo', '--noboo'
  # and '--st' are only prefixes of the flags defined for these tests.
  FAIL_TEST_CASES = [
      ('./program', '--boo', '0'),
      ('./program', '--boo=true', '0'),
      ('./program', '--boo=0'),
      ('./program', '--noboo'),
      ('./program', '--st=blah'),
      ('./program', '--st=de'),
      ('./program', '--st=blah', '--boo'),
      ('./program', '--st=blah', 'unused'),
      ('./program', '--st=--blah'),
      ('./program', '--st', '--blah'),
  ]

  def _get_test_flag_values(self):
    """Returns a new FlagValues defining the flags 'strf' and 'boolf'."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string('strf', '', '', flag_values=fv)
    _defines.DEFINE_boolean('boolf', 0, '', flag_values=fv)
    return fv

  @parameterized.parameters(FAIL_TEST_CASES)
  def test_raise(self, *argv):
    """Every failing argv raises UnrecognizedFlagError."""
    flag_values = self._get_test_flag_values()
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      flag_values(argv)

  @parameterized.parameters(
      FAIL_TEST_CASES + [('./program', 'unused', '--st=blah')])
  def test_gnu_getopt_raise(self, *argv):
    """Same as test_raise after set_gnu_getopt(), plus one extra argv."""
    flag_values = self._get_test_flag_values()
    flag_values.set_gnu_getopt()
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      flag_values(argv)
630
631
class SettingUnknownFlagTest(absltest.TestCase):
  """Tests assigning to flags that were never defined."""

  def setUp(self):
    """Resets the counter incremented by set_undef."""
    super().setUp()
    self.setter_called = 0

  def set_undef(self, unused_name, unused_val):
    """Unknown-flag setter that only counts how often it is invoked."""
    self.setter_called += 1

  def test_raise_on_undefined(self):
    """Without a registered setter, assignment raises UnrecognizedFlagError."""
    fv = _flagvalues.FlagValues()
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      fv.undefined_flag = 0

  def test_not_raise(self):
    """A registered setter receives the assignment instead of an error."""
    fv = _flagvalues.FlagValues()
    fv._register_unknown_flag_setter(self.set_undef)
    fv.undefined_flag = 0
    self.assertEqual(self.setter_called, 1)

  def test_not_raise_on_undefined_if_undefok(self):
    """Undefined flags named in --undefok are consumed when known_only=True."""
    fv = _flagvalues.FlagValues()
    leftover = fv(
        ['0', '--foo', '--bar=1', '--undefok=foo,bar'], known_only=True)
    self.assertEqual(['0'], leftover)

  def test_re_raise_undefined(self):
    """A NameError from the setter surfaces as UnrecognizedFlagError."""

    def failing_setter(unused_name, unused_val):
      raise NameError()

    fv = _flagvalues.FlagValues()
    fv._register_unknown_flag_setter(failing_setter)
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      fv.undefined_flag = 0

  def test_re_raise_invalid(self):
    """A ValueError from the setter surfaces as IllegalFlagValueError."""

    def failing_setter(unused_name, unused_val):
      raise ValueError()

    fv = _flagvalues.FlagValues()
    fv._register_unknown_flag_setter(failing_setter)
    with self.assertRaises(_exceptions.IllegalFlagValueError):
      fv.undefined_flag = 0
673
674
class SetAttributesTest(absltest.TestCase):
  """Tests FlagValues._set_attributes with defined and undefined flag names."""

  def setUp(self):
    """Creates a FlagValues with two boolean flags and resets the counter."""
    super().setUp()
    self.setter_called = 0
    self.new_flags = _flagvalues.FlagValues()
    for flag_name in ('defined_flag', 'another_defined_flag'):
      _defines.DEFINE_boolean(
          flag_name, None, '', flag_values=self.new_flags)

  def set_undef(self, unused_name, unused_val):
    """Unknown-flag setter that only counts how often it is invoked."""
    self.setter_called += 1

  def test_two_defined_flags(self):
    """Setting only defined flags never invokes the unknown-flag setter."""
    self.new_flags._set_attributes(
        defined_flag=False, another_defined_flag=False)
    self.assertEqual(self.setter_called, 0)

  def test_one_defined_one_undefined_flag(self):
    """An undefined name raises when no unknown-flag setter is registered."""
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)

  def test_register_unknown_flag_setter(self):
    """The registered setter is called once, for the single undefined name."""
    self.new_flags._register_unknown_flag_setter(self.set_undef)
    self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
    self.assertEqual(self.setter_called, 1)
702
703
class FlagsDashSyntaxTest(absltest.TestCase):
  """Tests how many leading dashes are accepted for long and short names."""

  def setUp(self):
    """Defines the string flag --long_name with the short name -s."""
    super().setUp()
    self.fv = _flagvalues.FlagValues()
    _defines.DEFINE_string(
        'long_name', 'default', 'help', short_name='s', flag_values=self.fv)

  def test_long_name_one_dash(self):
    """A single dash before the long name is accepted."""
    self.fv(['./program', '-long_name=new'])
    self.assertEqual('new', self.fv.long_name)

  def test_long_name_two_dashes(self):
    """Two dashes before the long name are accepted."""
    self.fv(['./program', '--long_name=new'])
    self.assertEqual('new', self.fv.long_name)

  def test_long_name_three_dashes(self):
    """Three dashes before the long name are not recognized."""
    argv = ['./program', '---long_name=new']
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      self.fv(argv)

  def test_short_name_one_dash(self):
    """A single dash before the short name is accepted."""
    self.fv(['./program', '-s=new'])
    self.assertEqual('new', self.fv.s)

  def test_short_name_two_dashes(self):
    """Two dashes before the short name are accepted."""
    self.fv(['./program', '--s=new'])
    self.assertEqual('new', self.fv.s)

  def test_short_name_three_dashes(self):
    """Three dashes before the short name are not recognized."""
    argv = ['./program', '---s=new']
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      self.fv(argv)
735
736
class UnparseFlagsTest(absltest.TestCase):
  """Tests the state of flags after FlagValues.unparse_flags()."""

  def test_using_default_value_none(self):
    """using_default_value flips on parse and is restored by unparse_flags."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string('default_none', None, 'help', flag_values=fv)
    self.assertTrue(fv['default_none'].using_default_value)

    # Two identical parse/unparse rounds with different values.
    for value in ('notNone', 'alsoNotNone'):
      fv(['', '--default_none=' + value])
      self.assertFalse(fv['default_none'].using_default_value)
      fv.unparse_flags()
      self.assertTrue(fv['default_none'].using_default_value)

  def test_using_default_value_not_none(self):
    """Same as above for a flag whose default is the string 'foo'."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string('default_foo', 'foo', 'help', flag_values=fv)
    fv.mark_as_parsed()
    self.assertTrue(fv['default_foo'].using_default_value)

    # Passing the default value explicitly still counts as "set".
    fv(['', '--default_foo=foo'])
    self.assertFalse(fv['default_foo'].using_default_value)

    fv(['', '--default_foo=notFoo'])
    self.assertFalse(fv['default_foo'].using_default_value)

    fv.unparse_flags()
    self.assertTrue(fv['default_foo'].using_default_value)

    fv(['', '--default_foo=alsoNotFoo'])
    self.assertFalse(fv['default_foo'].using_default_value)

  def test_allow_overwrite_false(self):
    """Flags with allow_overwrite=False can be parsed again after unparse."""
    fv = _flagvalues.FlagValues()
    for flag_name, default in (('default_none', None), ('default_foo', 'foo')):
      _defines.DEFINE_string(
          flag_name, default, 'help', allow_overwrite=False, flag_values=fv)

    fv.mark_as_parsed()
    self.assertEqual('foo', fv.default_foo)
    self.assertIsNone(fv.default_none)

    fv(['', '--default_foo=notFoo', '--default_none=notNone'])
    self.assertEqual('notFoo', fv.default_foo)
    self.assertEqual('notNone', fv.default_none)

    # After unparsing, the values are read through the flag objects.
    fv.unparse_flags()
    self.assertEqual('foo', fv['default_foo'].value)
    self.assertIsNone(fv['default_none'].value)

    fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone'])
    self.assertEqual('alsoNotFoo', fv.default_foo)
    self.assertEqual('alsoNotNone', fv.default_none)

  def test_multi_string_default_none(self):
    """A multi_string flag with a None default returns to None on unparse."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv)
    fv.mark_as_parsed()
    self.assertIsNone(fv.foo)

    # Single occurrence.
    fv(['', '--foo=aa'])
    self.assertEqual(['aa'], fv.foo)
    fv.unparse_flags()
    self.assertIsNone(fv['foo'].value)

    # Repeated occurrences.
    fv(['', '--foo=bb', '--foo=cc'])
    self.assertEqual(['bb', 'cc'], fv.foo)
    fv.unparse_flags()
    self.assertIsNone(fv['foo'].value)

  def test_multi_string_default_string(self):
    """A string default is expected to read back as a one-element list."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_multi_string('foo', 'xyz', 'help', flag_values=fv)
    default_as_list = ['xyz']
    fv.mark_as_parsed()
    self.assertEqual(default_as_list, fv.foo)

    # Single occurrence.
    fv(['', '--foo=aa'])
    self.assertEqual(['aa'], fv.foo)
    fv.unparse_flags()
    self.assertEqual(default_as_list, fv['foo'].value)

    # Repeated occurrences.
    fv(['', '--foo=bb', '--foo=cc'])
    self.assertEqual(['bb', 'cc'], fv['foo'].value)
    fv.unparse_flags()
    self.assertEqual(default_as_list, fv['foo'].value)

  def test_multi_string_default_list(self):
    """A list default is restored in full by unparse_flags."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_multi_string(
        'foo', ['xx', 'yy', 'zz'], 'help', flag_values=fv)
    default_list = ['xx', 'yy', 'zz']
    fv.mark_as_parsed()
    self.assertEqual(default_list, fv.foo)

    # Single occurrence.
    fv(['', '--foo=aa'])
    self.assertEqual(['aa'], fv.foo)
    fv.unparse_flags()
    self.assertEqual(default_list, fv['foo'].value)

    # Repeated occurrences.
    fv(['', '--foo=bb', '--foo=cc'])
    self.assertEqual(['bb', 'cc'], fv.foo)
    fv.unparse_flags()
    self.assertEqual(default_list, fv['foo'].value)
838
839
class UnparsedFlagAccessTest(absltest.TestCase):
  """Tests that flag values cannot be read while flags are unparsed."""

  def test_unparsed_flag_access(self):
    """Reading a flag attribute before parsing raises."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string('name', 'default', 'help', flag_values=fv)
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = fv.name

  def test_hasattr_raises_in_py3(self):
    """hasattr() on an unparsed flag raises the same error."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string('name', 'default', 'help', flag_values=fv)
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = hasattr(fv, 'name')

  def test_unparsed_flags_access_raises_after_unparse_flags(self):
    """unparse_flags() makes a previously readable flag raise again."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string('a_str', 'default_value', 'help', flag_values=fv)

    # Readable once the flags are marked as parsed.
    fv.mark_as_parsed()
    self.assertEqual(fv.a_str, 'default_value')

    fv.unparse_flags()
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = fv.a_str
862
863
class FlagHolderTest(absltest.TestCase):
  """Tests the object returned by the DEFINE_* functions."""

  def setUp(self):
    """Defines the string flag 'name' and keeps the returned holder."""
    super().setUp()
    self.fv = _flagvalues.FlagValues()
    self.name_flag = _defines.DEFINE_string(
        'name', 'default', 'help', flag_values=self.fv)

  def parse_flags(self, *argv):
    """Unparses self.fv, then parses `argv` after a dummy program name."""
    self.fv.unparse_flags()
    self.fv(['binary_name', *argv])

  def test_name(self):
    """The holder reports the name of its flag."""
    self.assertEqual('name', self.name_flag.name)

  def test_value_before_flag_parsing(self):
    """Reading .value before parsing raises UnparsedFlagAccessError."""
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = self.name_flag.value

  def test_value_returns_default_value_if_not_explicitly_set(self):
    """.value is the default when the flag is absent from argv."""
    self.parse_flags()
    self.assertEqual('default', self.name_flag.value)

  def test_value_returns_explicitly_set_value(self):
    """.value is the value given on the command line."""
    self.parse_flags('--name=new_value')
    self.assertEqual('new_value', self.name_flag.value)

  def test_present_returns_false_before_flag_parsing(self):
    """.present is falsy before any parsing has happened."""
    self.assertFalse(self.name_flag.present)

  def test_present_returns_false_if_not_explicitly_set(self):
    """.present is falsy when the flag is absent from argv."""
    self.parse_flags()
    self.assertFalse(self.name_flag.present)

  def test_present_returns_true_if_explicitly_set(self):
    """.present is truthy when the flag appears in argv."""
    self.parse_flags('--name=new_value')
    self.assertTrue(self.name_flag.present)

  def test_serializes_flag(self):
    """serialize() renders the flag in --name=value form."""
    self.parse_flags('--name=new_value')
    self.assertEqual('--name=new_value', self.name_flag.serialize())

  def test_allow_override(self):
    """Holders of an overridden flag and of its override agree after parsing."""
    holders = [
        _defines.DEFINE_integer(
            'int_flag', default, 'help', flag_values=self.fv, allow_override=1)
        for default in (1, 2)
    ]
    first, second = holders
    self.parse_flags('--int_flag=3')
    self.assertEqual(3, first.value)
    self.assertEqual(3, second.value)
    self.assertTrue(first.present)
    self.assertTrue(second.present)

  def test_eq(self):
    """Comparing a holder with == raises TypeError."""
    with self.assertRaises(TypeError):
      self.name_flag == 'value'  # pylint: disable=pointless-statement

  def test_eq_reflection(self):
    """The reflected comparison raises TypeError as well."""
    with self.assertRaises(TypeError):
      'value' == self.name_flag  # pylint: disable=pointless-statement

  def test_bool(self):
    """Converting a holder to bool raises TypeError."""
    with self.assertRaises(TypeError):
      bool(self.name_flag)
928
929
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  absltest.main()
932