MasumBhuiyan committed on
Commit
6fc394f
·
1 Parent(s): e5f704a

data.py updated to save processed data

Browse files
data/bn_multi_tribe_mt.txt ADDED
The diff for this file is too large to render. See raw diff
 
src/pipes/data.py CHANGED
@@ -1,6 +1,7 @@
1
  import string
2
  import os
3
  import random
 
4
 
5
  class Dataset:
6
  def __init__(self, sentences, vocab=None):
@@ -10,7 +11,6 @@ class Dataset:
10
  vocab = set()
11
  for sentence in sentences:
12
  vocab.update(sentence.split())
13
- print("Len", len(list(vocab)))
14
  return sorted(list(vocab))
15
 
16
  def remove_punctuation(self, sentence):
@@ -42,22 +42,54 @@ class Dataset:
42
 
43
  self.vocab = self.build_vocab(self.sentences)
44
 
45
- if max_length is None:
46
- max_length = max(len(sentence.split()) for sentence in self.sentences)
47
 
48
  processed_sentences = []
49
  for sentence in self.sentences:
50
  tokens = self.tokenize(sentence)
51
- padded_tokens = self.pad_sequence(tokens, max_length)
52
  processed_sentences.append(padded_tokens)
53
 
54
- return processed_sentences
 
 
 
 
55
 
 
 
 
 
 
 
56
 
57
- def load_data(file_path):
58
- sentences = []
59
- with open(file_path, 'r', encoding='utf-8') as f:
60
- sentences = f.readlines()
61
  dataset = Dataset(sentences)
62
- processed_sentences = dataset.process()
63
- return processed_sentences
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import string
2
  import os
3
  import random
4
+ import utils
5
 
6
  class Dataset:
7
  def __init__(self, sentences, vocab=None):
 
11
  vocab = set()
12
  for sentence in sentences:
13
  vocab.update(sentence.split())
 
14
  return sorted(list(vocab))
15
 
16
  def remove_punctuation(self, sentence):
 
42
 
43
  self.vocab = self.build_vocab(self.sentences)
44
 
45
+ self.max_length = max(len(sentence.split()) for sentence in self.sentences)
 
46
 
47
  processed_sentences = []
48
  for sentence in self.sentences:
49
  tokens = self.tokenize(sentence)
50
+ padded_tokens = self.pad_sequence(tokens, self.max_length)
51
  processed_sentences.append(padded_tokens)
52
 
53
+ data_dict = {}
54
+ data_dict["max_seq_len"] = self.max_length
55
+ data_dict["vocab_size"] = len(self.vocab)
56
+ data_dict["vocab"] = self.vocab
57
+ return processed_sentences, data_dict
58
 
59
if __name__ == "__main__":
    # Build the parallel gr/bn corpus: tokenize and pad each side, shuffle
    # the aligned sentence pairs, split 80/20 into train/val, and save
    # everything as a single JSON document.

    # Process the gr side.
    sentences = utils.read_txt("E:/bn_multi_tribe_mt/data/raw/gr.txt")
    dataset = Dataset(sentences)
    gr, gr_dict = dataset.process()

    # Process the bn side.
    sentences = utils.read_txt("E:/bn_multi_tribe_mt/data/raw/bn.txt")
    dataset = Dataset(sentences)
    bn, bn_dict = dataset.process()

    # The corpora must be parallel (aligned line-for-line). zip() would
    # silently truncate to the shorter side, corrupting the alignment.
    if len(gr) != len(bn):
        raise ValueError(
            f"Parallel corpora length mismatch: gr={len(gr)}, bn={len(bn)}"
        )

    # Shuffle the aligned pairs together so the gr<->bn pairing survives.
    zipped = list(zip(gr, bn))
    random.shuffle(zipped)
    if zipped:  # zip(*[]) would raise ValueError on an empty corpus
        gr, bn = zip(*zipped)

    # 80/20 train/validation split.
    split_id = int(len(gr) * 0.8)

    gr_train = gr[:split_id]
    gr_val = gr[split_id:]
    bn_train = bn[:split_id]
    bn_val = bn[split_id:]

    # Attach the splits to each side's metadata (max_seq_len, vocab, ...)
    # and save both sides under one top-level dict.
    gr_dict["train"] = gr_train
    gr_dict["val"] = gr_val
    bn_dict["train"] = bn_train
    bn_dict["val"] = bn_val

    data_dict = {"gr": gr_dict, "bn": bn_dict}
    utils.save_dict("E:/bn_multi_tribe_mt/data/bn_multi_tribe_mt.txt", data_dict)

    # print("Loaded dict: ", utils.load_dict("E:/bn_multi_tribe_mt/data/bn_multi_tribe_mt.txt")["bn"]["val"])
src/pipes/utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
def save_dict(file_path, my_dict, encoding='utf-8'):
    """Serialize my_dict to file_path as JSON.

    Non-ASCII characters (e.g. Bengali text) are written as-is rather
    than escaped, so the saved file stays human-readable.
    """
    serialized = json.dumps(my_dict, ensure_ascii=False)
    with open(file_path, "w", encoding=encoding) as out_file:
        out_file.write(serialized)
6
+
7
def load_dict(file_path, encoding='utf-8'):
    """Read a JSON file (as written by save_dict) and return the dict.

    The encoding parameter defaults to UTF-8 and is exposed for
    consistency with save_dict's signature.
    """
    with open(file_path, "r", encoding=encoding) as f:
        return json.load(f)
11
+
12
def read_txt(file_path):
    """Return the lines of a UTF-8 text file, newline characters preserved.

    The original initialized `sentences = []` only to overwrite it
    immediately; the dead assignment is removed.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return f.readlines()