Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce cms python types in configurations #41361

Merged
merged 7 commits into from
May 15, 2023
Merged
1 change: 0 additions & 1 deletion FWCore/Framework/test/run_wrongOptionsType.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ function die { echo $1: status $2 ; echo === Log file === ; cat ${3:-/dev/null}

cmsRun ${SCRAM_TEST_PATH}/test_wrongOptionsType_cfg.py -- --name=${NAME} --value="$VALUE" > ${NAME}.log 2>&1 && die "cmsRun for ${NAME} succeeded, should have failed" 1 ${NAME}.log
grep -E "(The type in the configuration is incorrect)|(ValueError type of .* is expected to be .* but declared as)" ${NAME}.log > /dev/null || die "cmsRun for ${NAME} failed for other reason than incorrect configuration type" $? ${NAME}.log

3 changes: 2 additions & 1 deletion FWCore/Framework/test/test_wrongOptionsType_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
process.source = cms.Source("EmptySource")

process.maxEvents.input = 2

#avoid type check in python to force check in C++
delattr(process.options, args.name)
setattr(process.options, args.name, eval(str(args.value)))
17 changes: 14 additions & 3 deletions FWCore/ParameterSet/python/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def __init__(self,name,*Mods):
self.__isStrict = False
self.__dict__['_Process__modifiers'] = Mods
self.__dict__['_Process__accelerators'] = {}
self.options = Process.defaultOptions_()
self.maxEvents = Process.defaultMaxEvents_()
self.__injectValidValue('options', Process.defaultOptions_())
self.__injectValidValue('maxEvents', Process.defaultMaxEvents_())
self.maxLuminosityBlocks = Process.defaultMaxLuminosityBlocks_()
# intentionally not cloned to ensure that everyone taking
# MessageLogger still via
Expand Down Expand Up @@ -557,6 +557,10 @@ def __setattr__(self,name,value):
self._replaceInScheduleDirectly(name, newValue)

self._delattrFromSetattr(name)
self.__injectValidValue(name, value, newValue)
def __injectValidValue(self, name, value, newValue = None):
if newValue is None:
newValue = value
self.__dict__[name]=newValue
if isinstance(newValue,_Labelable):
self.__setObjectLabel(newValue, name)
Expand Down Expand Up @@ -3600,6 +3604,10 @@ def testOptions(self):
numberOfThreads =2)
self.assertEqual(p.options.numberOfThreads.value(),2)
self.assertEqual(p.options.numberOfStreams.value(),2)
del p.options
self.assertRaises(TypeError, setattr, p, 'options', untracked.PSet(numberOfThreads = int32(-1)))
p.options = untracked.PSet(numberOfThreads = untracked.uint32(4))
self.assertEqual(p.options.numberOfThreads.value(), 4)

def testMaxEvents(self):
p = Process("Test")
Expand All @@ -3614,7 +3622,10 @@ def testMaxEvents(self):
p = Process("Test")
p.maxEvents = untracked.PSet(input = untracked.int32(5))
self.assertEqual(p.maxEvents.input.value(), 5)

del p.maxEvents
self.assertRaises(TypeError, setattr, p, 'maxEvents', untracked.PSet(input = untracked.uint32(1)))
p.maxEvents = untracked.PSet(input = untracked.int32(1))
self.assertEqual(p.maxEvents.input.value(), 1)

def testExamples(self):
p = Process("Test")
Expand Down
9 changes: 7 additions & 2 deletions FWCore/ParameterSet/python/MassReplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,12 @@ def testMassSearchReplaceAnyInputTag(self):
),
)
p.op = cms.EDProducer("op", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
p.op2 = cms.EDProducer("op2", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
p.op.src="b"
p.op.vsrc=cms.VInputTag("b")
p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op)
p.op.vsrc = ["b"]
p.op2.src=cms.InputTag("b")
p.op2.vsrc = cms.VInputTag("b")
p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op*p.op2)
massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
self.assertNotEqual(cms.InputTag("new"), p.b.src)
self.assertEqual(cms.InputTag("new"), p.c.src)
Expand Down Expand Up @@ -210,6 +213,8 @@ def testMassSearchReplaceAnyInputTag(self):
self.assertEqual(cms.untracked.InputTag("new"), p.sp.test2.nested.usrc)
self.assertEqual(cms.InputTag("new"), p.op.src)
self.assertEqual(cms.InputTag("new"), p.op.vsrc[0])
self.assertEqual(cms.InputTag("new"), p.op2.src)
self.assertEqual(cms.InputTag("new"), p.op2.vsrc[0])

def testMassReplaceInputTag(self):
process1 = cms.Process("test")
Expand Down
31 changes: 26 additions & 5 deletions FWCore/ParameterSet/python/Mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def setIsFrozen(self):
self._isFrozen = True
def isCompatibleCMSType(self,aType):
return isinstance(self,aType)
def _checkAndReturnValueWithType(self, valueWithType):
if isinstance(valueWithType, type(self)):
return valueWithType
raise TypeError("Attempted to assign type {from_} to type {to}".format(from_ = str(type(valueWithType)), to = str(type(self))) )


class _SimpleParameterTypeBase(_ParameterTypeBase):
"""base class for parameter classes which only hold a single value"""
Expand Down Expand Up @@ -277,7 +282,7 @@ def __setattr__(self,name,value):
else:
# handle the case where users just replace with a value, a = 12, rather than a = cms.int32(12)
if isinstance(value,_ParameterTypeBase):
self.__dict__[name] = value
self.__dict__[name] = self.__dict__[name]._checkAndReturnValueWithType(value)
else:
self.__dict__[name].setValue(value)
self._isModified = True
Expand Down Expand Up @@ -579,7 +584,7 @@ def __init__(self,*arg,**args):
super(_ValidatingListBase,self).__init__(arg)
if 0 != len(args):
raise SyntaxError("named arguments ("+','.join([x for x in args])+") passsed to "+str(type(self)))
if not self._isValid(iter(self)):
if not type(self)._isValid(iter(self)):
raise TypeError("wrong types ("+','.join([str(type(value)) for value in iter(self)])+
") added to "+str(type(self)))
def __setitem__(self,key,value):
Expand All @@ -590,12 +595,13 @@ def __setitem__(self,key,value):
if not self._itemIsValid(value):
raise TypeError("can not insert the type "+str(type(value))+" in container "+self._labelIfAny())
super(_ValidatingListBase,self).__setitem__(key,value)
def _isValid(self,seq):
@classmethod
def _isValid(cls,seq):
# see if strings get reinterpreted as lists
if isinstance(seq, str):
return False
for item in seq:
if not self._itemIsValid(item):
if not cls._itemIsValid(item):
return False
return True
def _itemFromArgument(self, x):
Expand Down Expand Up @@ -753,7 +759,8 @@ def _modifyParametersFromDict(params, newParams, errorRaiser, keyDepth=""):

import unittest
class TestList(_ValidatingParameterListBase):
def _itemIsValid(self,item):
@classmethod
def _itemIsValid(cls,item):
return True
class testMixins(unittest.TestCase):
def testListConstruction(self):
Expand Down Expand Up @@ -938,5 +945,19 @@ def testSpecialImportRegistry(self):
self.assertEqual(reg.getSpecialImports(), ["import foo"])
reg.registerUse("a")
self.assertEqual(reg.getSpecialImports(), ["import bar", "import foo"])
def testInvalidTypeChange(self):
class __Test(_TypedParameterizable):
pass
class __TestTypeA(_SimpleParameterTypeBase):
def _isValid(self,value):
return True
class __TestTypeB(_SimpleParameterTypeBase):
def _isValid(self,value):
return True
pass
a = __Test("MyType",
t=__TestTypeA(1))
self.assertRaises(TypeError, lambda : setattr(a,'t',__TestTypeB(2)))


unittest.main()
Loading