Skip to content

Commit

Permalink
[Poly-encoder] Fixes for DSTC7 Task (facebookresearch#2314)
Browse files Browse the repository at this point in the history
* updates to dstc7 task

* black

* update docstring
  • Loading branch information
klshuster authored and ggdupont committed Jan 22, 2020
1 parent ec1840c commit aa694e0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 24 deletions.
74 changes: 54 additions & 20 deletions parlai/tasks/dstc7/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,28 @@ def __init__(self, opt, shared=None):
filepath = os.path.join(
basedir, 'ubuntu_%s_subtask_1%s.json' % (self.split, self.get_suffix())
)
with open(filepath, 'r') as f:
self.data = json.loads(f.read())

# special case of test set
if self.split == "test":
id_to_res = {}
with open(
os.path.join(basedir, "ubuntu_responses_subtask_1.tsv"), 'r'
) as f:
for line in f:
splited = line[0:-1].split("\t")
id_ = splited[0]
id_res = splited[1]
res = splited[2]
id_to_res[id_] = [{"candidate-id": id_res, "utterance": res}]
for sample in self.data:
sample["options-for-correct-answers"] = id_to_res[
str(sample["example-id"])
]
if shared is not None:
self.data = shared['data']
else:
with open(filepath, 'r') as f:
self.data = json.loads(f.read())

# special case of test set
if self.split == "test":
id_to_res = {}
with open(
os.path.join(basedir, "ubuntu_responses_subtask_1.tsv"), 'r'
) as f:
for line in f:
splited = line[0:-1].split("\t")
id_ = splited[0]
id_res = splited[1]
res = splited[2]
id_to_res[id_] = [{"candidate-id": id_res, "utterance": res}]
for sample in self.data:
sample["options-for-correct-answers"] = id_to_res[
str(sample["example-id"])
]

super().__init__(opt, shared)
self.reset()
Expand Down Expand Up @@ -96,6 +99,37 @@ def share(self):
return shared


class DSTC7TeacherAugmented(DSTC7Teacher):
    """
    Teacher over the augmented DSTC7 training data.

    ParlAI usually presents a dialogue as a sequence of growing contexts, and
    the augmented data follows the same convention: one original dialogue is
    expanded into several entries, each ending on a different label.

    For example, given a dialogue between speakers 1 and 2 with

        utterances: [A, B, C, D, E]
        label: F

    the augmented file contains the following episodes:

        ep1:
            utterances: [A]
            label: B
        ep2:
            utterances: [A, B, C]
            label: D
        ep3:
            utterances: [A, B, C, D, E]
            label: F
    """

    def get_suffix(self):
        """
        Return the suffix selecting the data variant for the current split.

        Only the ``train`` split has an augmented variant; every other split
        yields an empty suffix.
        """
        return "_augmented" if self.split == "train" else ""


class DSTC7TeacherAugmentedSampled(DSTC7Teacher):
"""
The dev and test set are the same, but the training set has been augmented using the
Expand All @@ -107,7 +141,7 @@ class DSTC7TeacherAugmentedSampled(DSTC7Teacher):
def get_suffix(self):
if self.split != "train":
return ""
return "_augmented"
return "_sampled"

def get_nb_cands(self):
return 16
Expand Down
8 changes: 4 additions & 4 deletions parlai/tasks/dstc7/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@

RESOURCES = [
DownloadableFile(
'http://parl.ai/downloads/dstc7/dstc7.tar.gz',
'dstc7.tar.gz',
'aa3acec0aedb660f1549cdd802f01e5bc9c5b9dc06f10764c5e20686aa4d5571',
'http://parl.ai/downloads/dstc7/dstc7_v2.tgz',
'dstc7_v2.tgz',
'cc8fd830f9894768ab4f7b104cddd4105456812ab614041337ec12c5a3a56685',
)
]


def build(opt):
dpath = os.path.join(opt['datapath'], 'dstc7')
version = None
version = '2.0'

if not build_data.built(dpath, version_string=version):
print('[building data: ' + dpath + ']')
Expand Down

0 comments on commit aa694e0

Please sign in to comment.