Skip to content

Commit

Permalink
Merge pull request #41361 from Dr15Jones/enforceCMSPythonTypes
Browse files Browse the repository at this point in the history
Enforce cms python types in configurations
  • Loading branch information
cmsbuild authored May 15, 2023
2 parents e2d3811 + e081f54 commit 7d9746a
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 50 deletions.
1 change: 0 additions & 1 deletion FWCore/Framework/test/run_wrongOptionsType.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ function die { echo $1: status $2 ; echo === Log file === ; cat ${3:-/dev/null}

cmsRun ${SCRAM_TEST_PATH}/test_wrongOptionsType_cfg.py -- --name=${NAME} --value="$VALUE" > ${NAME}.log 2>&1 && die "cmsRun for ${NAME} succeeded, should have failed" 1 ${NAME}.log
grep -E "(The type in the configuration is incorrect)|(ValueError type of .* is expected to be .* but declared as)" ${NAME}.log > /dev/null || die "cmsRun for ${NAME} failed for other reason than incorrect configuration type" $? ${NAME}.log

3 changes: 2 additions & 1 deletion FWCore/Framework/test/test_wrongOptionsType_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
process.source = cms.Source("EmptySource")

process.maxEvents.input = 2

#avoid type check in python to force check in C++
delattr(process.options, args.name)
setattr(process.options, args.name, eval(str(args.value)))
17 changes: 14 additions & 3 deletions FWCore/ParameterSet/python/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def __init__(self,name,*Mods):
self.__isStrict = False
self.__dict__['_Process__modifiers'] = Mods
self.__dict__['_Process__accelerators'] = {}
self.options = Process.defaultOptions_()
self.maxEvents = Process.defaultMaxEvents_()
self.__injectValidValue('options', Process.defaultOptions_())
self.__injectValidValue('maxEvents', Process.defaultMaxEvents_())
self.maxLuminosityBlocks = Process.defaultMaxLuminosityBlocks_()
# intentionally not cloned to ensure that everyone taking
# MessageLogger still via
Expand Down Expand Up @@ -557,6 +557,10 @@ def __setattr__(self,name,value):
self._replaceInScheduleDirectly(name, newValue)

self._delattrFromSetattr(name)
self.__injectValidValue(name, value, newValue)
def __injectValidValue(self, name, value, newValue = None):
if newValue is None:
newValue = value
self.__dict__[name]=newValue
if isinstance(newValue,_Labelable):
self.__setObjectLabel(newValue, name)
Expand Down Expand Up @@ -3600,6 +3604,10 @@ def testOptions(self):
numberOfThreads =2)
self.assertEqual(p.options.numberOfThreads.value(),2)
self.assertEqual(p.options.numberOfStreams.value(),2)
del p.options
self.assertRaises(TypeError, setattr, p, 'options', untracked.PSet(numberOfThreads = int32(-1)))
p.options = untracked.PSet(numberOfThreads = untracked.uint32(4))
self.assertEqual(p.options.numberOfThreads.value(), 4)

def testMaxEvents(self):
p = Process("Test")
Expand All @@ -3614,7 +3622,10 @@ def testMaxEvents(self):
p = Process("Test")
p.maxEvents = untracked.PSet(input = untracked.int32(5))
self.assertEqual(p.maxEvents.input.value(), 5)

del p.maxEvents
self.assertRaises(TypeError, setattr, p, 'maxEvents', untracked.PSet(input = untracked.uint32(1)))
p.maxEvents = untracked.PSet(input = untracked.int32(1))
self.assertEqual(p.maxEvents.input.value(), 1)

def testExamples(self):
p = Process("Test")
Expand Down
9 changes: 7 additions & 2 deletions FWCore/ParameterSet/python/MassReplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,12 @@ def testMassSearchReplaceAnyInputTag(self):
),
)
p.op = cms.EDProducer("op", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
p.op2 = cms.EDProducer("op2", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
p.op.src="b"
p.op.vsrc=cms.VInputTag("b")
p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op)
p.op.vsrc = ["b"]
p.op2.src=cms.InputTag("b")
p.op2.vsrc = cms.VInputTag("b")
p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op*p.op2)
massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
self.assertNotEqual(cms.InputTag("new"), p.b.src)
self.assertEqual(cms.InputTag("new"), p.c.src)
Expand Down Expand Up @@ -210,6 +213,8 @@ def testMassSearchReplaceAnyInputTag(self):
self.assertEqual(cms.untracked.InputTag("new"), p.sp.test2.nested.usrc)
self.assertEqual(cms.InputTag("new"), p.op.src)
self.assertEqual(cms.InputTag("new"), p.op.vsrc[0])
self.assertEqual(cms.InputTag("new"), p.op2.src)
self.assertEqual(cms.InputTag("new"), p.op2.vsrc[0])

def testMassReplaceInputTag(self):
process1 = cms.Process("test")
Expand Down
31 changes: 26 additions & 5 deletions FWCore/ParameterSet/python/Mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def setIsFrozen(self):
self._isFrozen = True
def isCompatibleCMSType(self,aType):
return isinstance(self,aType)
def _checkAndReturnValueWithType(self, valueWithType):
if isinstance(valueWithType, type(self)):
return valueWithType
raise TypeError("Attempted to assign type {from_} to type {to}".format(from_ = str(type(valueWithType)), to = str(type(self))) )


class _SimpleParameterTypeBase(_ParameterTypeBase):
"""base class for parameter classes which only hold a single value"""
Expand Down Expand Up @@ -277,7 +282,7 @@ def __setattr__(self,name,value):
else:
# handle the case where users just replace with a value, a = 12, rather than a = cms.int32(12)
if isinstance(value,_ParameterTypeBase):
self.__dict__[name] = value
self.__dict__[name] = self.__dict__[name]._checkAndReturnValueWithType(value)
else:
self.__dict__[name].setValue(value)
self._isModified = True
Expand Down Expand Up @@ -579,7 +584,7 @@ def __init__(self,*arg,**args):
super(_ValidatingListBase,self).__init__(arg)
if 0 != len(args):
raise SyntaxError("named arguments ("+','.join([x for x in args])+") passsed to "+str(type(self)))
if not self._isValid(iter(self)):
if not type(self)._isValid(iter(self)):
raise TypeError("wrong types ("+','.join([str(type(value)) for value in iter(self)])+
") added to "+str(type(self)))
def __setitem__(self,key,value):
Expand All @@ -590,12 +595,13 @@ def __setitem__(self,key,value):
if not self._itemIsValid(value):
raise TypeError("can not insert the type "+str(type(value))+" in container "+self._labelIfAny())
super(_ValidatingListBase,self).__setitem__(key,value)
def _isValid(self,seq):
@classmethod
def _isValid(cls,seq):
# see if strings get reinterpreted as lists
if isinstance(seq, str):
return False
for item in seq:
if not self._itemIsValid(item):
if not cls._itemIsValid(item):
return False
return True
def _itemFromArgument(self, x):
Expand Down Expand Up @@ -753,7 +759,8 @@ def _modifyParametersFromDict(params, newParams, errorRaiser, keyDepth=""):

import unittest
class TestList(_ValidatingParameterListBase):
def _itemIsValid(self,item):
@classmethod
def _itemIsValid(cls,item):
return True
class testMixins(unittest.TestCase):
def testListConstruction(self):
Expand Down Expand Up @@ -938,5 +945,19 @@ def testSpecialImportRegistry(self):
self.assertEqual(reg.getSpecialImports(), ["import foo"])
reg.registerUse("a")
self.assertEqual(reg.getSpecialImports(), ["import bar", "import foo"])
def testInvalidTypeChange(self):
class __Test(_TypedParameterizable):
pass
class __TestTypeA(_SimpleParameterTypeBase):
def _isValid(self,value):
return True
class __TestTypeB(_SimpleParameterTypeBase):
def _isValid(self,value):
return True
pass
a = __Test("MyType",
t=__TestTypeA(1))
self.assertRaises(TypeError, lambda : setattr(a,'t',__TestTypeB(2)))


unittest.main()
Loading

0 comments on commit 7d9746a

Please sign in to comment.