Skip to content
Draft

first #240

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/dolma/cli/deduper.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def run(cls, parsed_config: DeduperConfig):
# perform some path validation to make sure we don't call the mixer with invalid config
total_matching_documents = 0
for document in parsed_config.documents:

if not any(
fnmatch.fnmatch(dict_config["dedupe"]["document_dir"], part) for part in document.split(os.sep)
):
Expand Down
5 changes: 4 additions & 1 deletion python/dolma/cli/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class StreamConfig:
"from the file extension."
),
)
document_dir: str = field(
default="documents", help="Folder in source path to replace with 'attributes' when looking for attributes"
)


@dataclass
Expand Down Expand Up @@ -145,7 +148,6 @@ def run(cls, parsed_config: MixerConfig):
# perform some path validation to make sure we don't call the mixer with invalid config
total_matching_documents = 0
for document in stream_config.documents:

current_matching_documents = sum(1 for _ in glob_path(document))
if current_matching_documents == 0:
# only raise a warning if no documents are found for a single path
Expand All @@ -159,6 +161,7 @@ def run(cls, parsed_config: MixerConfig):
# populate the stream config dict
stream_config_dict["name"] = stream_config.name
stream_config_dict["documents"] = [str(d) for d in stream_config.documents]
stream_config_dict["document_dir"] = stream_config.document_dir
stream_config_dict["attributes"] = [str(a) for a in list(stream_config.attributes)]
stream_config_dict["output"] = {
"path": str(stream_config.output.path),
Expand Down
5 changes: 5 additions & 0 deletions python/dolma/cli/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class TaggerConfig:
default=False,
help="If true, only print the configuration and exit without running the taggers.",
)
document_dir: str = field(
default="documents",
help="The folder in source paths to replace with 'attributes' to store results, if not 'documents'",
)


class TaggerCli(BaseCli):
Expand Down Expand Up @@ -140,6 +144,7 @@ def run(cls, parsed_config: TaggerConfig):
profile_output=parsed_config.profile.output,
profile_steps=parsed_config.profile.steps,
profile_sort_key=parsed_config.profile.sort_key,
document_dir=parsed_config.document_dir,
)


Expand Down
3 changes: 2 additions & 1 deletion python/dolma/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def create_and_run_tagger(
profile_steps: Optional[int] = None,
profile_sort_key: str = "tottime",
profile_lines: int = 100,
document_dir: str = "documents",
):
"""This function creates a tagger and runs it on a list of documents.

Expand Down Expand Up @@ -444,7 +445,7 @@ def create_and_run_tagger(

if destination is None:
try:
destination = _make_paths_from_substitution(documents, "documents", f"attributes/{experiment}")
destination = _make_paths_from_substitution(documents, document_dir, f"attributes/{experiment}")
except Exception as exp:
raise RuntimeError("Could not make destination paths from documents paths") from exp
elif isinstance(destination, str):
Expand Down
6 changes: 6 additions & 0 deletions python/dolma/warc/linearizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,9 @@ def linearize(self, content: Union[str, bytes]) -> str:
)
self._flush()
return output or ""


@LinearizerRegistry.add("no-op")
class NoOpLinearizer(BaseLinearizer):
    """Linearizer that performs no linearization: the content is returned as-is.

    Used with `skip_linearization`-style pipelines so that raw HTML is stored in
    the `text` field instead of extracted text.
    """

    def linearize(self, content: Union[str, bytes]) -> str:
        # BUG FIX: `str(b"...")` returns the repr string "b'...'", not the decoded
        # payload. Decode bytes explicitly; `errors="replace"` keeps this total
        # over arbitrary (possibly mis-encoded) web content.
        if isinstance(content, bytes):
            return content.decode("utf-8", errors="replace")
        return content
1 change: 1 addition & 0 deletions python/dolma/warc/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def create_and_run_warc_pipeline(
store_html_in_metadata: bool = False,
skip_no_pre_taggers: bool = False,
skip_no_post_taggers: bool = False,
skip_linearization: bool = False,
):
with ExitStack() as stack:
if metadata is None:
Expand Down
18 changes: 14 additions & 4 deletions src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ impl Shard {
pub fn split_streams(streams: &Vec<StreamConfig>) -> Result<Vec<Shard>, IoError> {
let mut shards: Vec<Shard> = Vec::new();
for stream_config in streams {
let document_dir = format!(
"/{}/",
stream_config.document_dir.as_deref().unwrap_or("documents")
);
let mut stream_shard_count = 0;
log::info!("Computing shards for stream {}...", stream_config.name);
let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?;
Expand All @@ -50,7 +54,7 @@ impl Shard {
let mut attr_paths = Vec::new();
for prefix in stream_config.attributes.iter() {
let attr_prefix = format!("/attributes/{}/", prefix);
let attr_path = input.replace("/documents/", &attr_prefix);
let attr_path = input.replace(&document_dir, &attr_prefix);
attr_paths.push(attr_path);
}
(
Expand Down Expand Up @@ -135,13 +139,17 @@ impl Shard {
// dataset is a strict subset of the original and is intended to be unshuffled and unsharded.
let mut shards: Vec<Shard> = Vec::new();
for stream_config in streams {
let document_dir = format!(
"/{}/",
stream_config.document_dir.as_deref().unwrap_or("documents")
);
let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?;
let input_count = stream_inputs.len();
let inputs = stream_inputs.into_iter().map(|input| {
let mut attr_paths = Vec::new();
for prefix in stream_config.attributes.iter() {
let attr_prefix = format!("/attributes/{}/", prefix);
let attr_path = input.replace("/documents/", &attr_prefix);
let attr_path = input.replace(&document_dir, &attr_prefix);
attr_paths.push(attr_path);
}
DocumentPaths {
Expand All @@ -152,10 +160,11 @@ impl Shard {

for input in inputs {
let doc_path_clone = input.doc_path.clone();
let output_suffix = doc_path_clone.split("/documents/").last().unwrap();
let output_suffix = doc_path_clone.split(&document_dir).last().unwrap();
let output = format!(
"{}/documents/{}",
"{}{}{}",
stream_config.output.path.clone(),
document_dir,
output_suffix
);
log::info!("Creating shard for {}", output);
Expand Down Expand Up @@ -543,6 +552,7 @@ pub mod shard_config {
pub span_replacement: Option<Vec<SpanReplacementConfig>>,
pub output: StreamOutputConfig,
pub compression: Option<CompressionConfig>,
pub document_dir: Option<String>,
}

#[derive(Serialize, Deserialize, Clone)]
Expand Down
34 changes: 34 additions & 0 deletions tests/config/alt-path-mixer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"streams": [
{
"name": "mixer-test",
"documents": [
"tests/data/provided/alternative_term/*.gz"
],
"document_dir":"alternative_term",
"output": {
"path": "tests/work/output/mixer",
"max_size_in_bytes": 100000
},
"attributes": [
"pii",
"toxicity"
],
"filter": {
"include": [
"$.metadata[?(@.length < 10000)]"
],
"exclude": [
"$.metadata[?(@.length < 500)]",
"$.attributes[?(@.pii.too_much_pii == true)]",
"$.attributes[?(@.toxicity > 0.8)]"
]
}
}
],
"work_dir": {
"input": "tests/work/temp/mixer/input",
"output": "tests/work/temp/mixer/output"
},
"processes": 1
}
Binary file added tests/data/provided/alternative_term/000.json.gz
Binary file not shown.
31 changes: 31 additions & 0 deletions tests/python/test_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
EMAIL_SPANS_JQ = Path(__file__).parent.parent / "config/email-spans-jq.yaml"
FILTER_BY_SPANS = Path(__file__).parent.parent / "config/filter-by-spans.json"
MIXER = Path(__file__).parent.parent / "config/mixer.json"
ALT_DOC_PATH_MIXER = Path(__file__).parent.parent / "config/alt-path-mixer.json"

PARAGRAPH_SPANS = Path(__file__).parent.parent / "config/paragraph-spans.json"


Expand Down Expand Up @@ -150,6 +152,35 @@ def test_remote_input_remote_output(self):
provided = self.checkAndRemoveProvenance(provided)
self.assertEqual(expected, provided)

def test_alt_doc_path_mixer(self):
    """Mix documents whose source folder is 'alternative_term' instead of 'documents'.

    Exercises the new `document_dir` stream option end to end against S3.
    """
    if self.remote_test_prefix is None:
        return self.skipTest("Skipping AWS tests")

    with open(ALT_DOC_PATH_MIXER, mode="r", encoding="utf8") as f:
        config = json.load(f)

    # keep track of local input/output paths before rewriting the config
    local_input = config["streams"][0]["documents"][0]
    local_output = config["streams"][0]["output"]["path"]

    # replace results path with s3 path
    config["streams"][0]["output"]["path"] = f"{self.remote_test_prefix}/{local_output}"

    # BUG FIX: the comment promised an upload but none was performed, so the mixer
    # would find zero documents at the remote path. Upload local input to s3 first,
    # then point the config at the s3 location.
    # NOTE(review): key layout assumes upload_s3_prefix mirrors the local path under
    # the prefix when local_prefix is a glob — confirm against utils.upload_s3_prefix.
    upload_s3_prefix(self.remote_test_prefix, local_input)
    config["streams"][0]["documents"][0] = f"{self.remote_test_prefix}/{local_input}"

    with NamedTemporaryFile("w") as f:
        json.dump(config, f)
        f.flush()
        # run inside the context manager: the temp file is deleted on close
        main(argv=["-c", f.name, "mix"])

    download_s3_prefix(f"{self.remote_test_prefix}/tests/work", "tests/work/remote")
    expected = load_jsonl("tests/data/expected/mixer.json.gz")
    provided = load_jsonl("tests/work/remote/output/mixer/mixer-test-0000.json.gz")
    provided = self.checkAndRemoveProvenance(provided)
    self.assertEqual(expected, provided)

def test_remote_input_local_output(self):
if self.remote_test_prefix is None:
return self.skipTest("Skipping AWS tests")
Expand Down
2 changes: 0 additions & 2 deletions tests/python/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ def test_split_glob(self):

class TestSplitExt(TestCase):
def test_file(self):

prot, parts, ext = split_ext("file.txt")

self.assertEqual(prot, "")
Expand All @@ -318,7 +317,6 @@ def test_file(self):
self.assertEqual(ext, ".")

def test_path(self):

prot, parts, ext = split_ext("path/to/file.txt")

self.assertEqual(prot, "")
Expand Down
56 changes: 56 additions & 0 deletions tests/python/test_warc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,59 @@ def test_pretag_html(self):
{"by_4_0", "by_3_0"},
)
self.assertIn("cc_re__cc_re__cc_by_4_0", sample1[2]["attributes"])

def test_skip_linearization(self):
    """Test that when skip_linearization is True, the raw HTML content is preserved."""
    outputs = self._run_pipeline_with_skip_linearization()

    # both shards of the fixture must be produced
    self.assertEqual(len(outputs), 2)
    self.assertIn("sample-0000.jsonl.gz", outputs)
    self.assertIn("sample-0001.jsonl.gz", outputs)

    sample0 = outputs["sample-0000.jsonl.gz"]
    sample1 = outputs["sample-0001.jsonl.gz"]

    # each shard must contain at least one document
    self.assertGreater(len(sample0), 0)
    self.assertGreater(len(sample1), 0)

    for sample in chain(sample0, sample1):
        text = sample["text"]

        # raw (non-linearized) HTML keeps its markup in the text field
        self.assertIn("<", text)
        self.assertIn(">", text)

        # at least one common HTML tag should appear in the raw content
        lowered = text.lower()
        html_indicators = ["<html", "<body", "<div", "<p"]
        self.assertTrue(any(indicator in lowered for indicator in html_indicators))

        # document metadata written by the pipeline is unaffected by skipping linearization
        self.assertEqual(sample["version"], "v0")
        self.assertEqual(sample["source"], "test")
        metadata = sample["metadata"]
        self.assertIn("warc_url", metadata)
        self.assertIn("url", metadata)
        self.assertIn("warc_date", metadata)
        self.assertIn("warc_filename", metadata)
        self.assertIn("content_type", metadata)

def _run_pipeline_with_skip_linearization(self) -> Dict[str, List[dict]]:
    """Run the WARC pipeline with the 'no-op' linearizer and collect outputs keyed by file name."""
    create_and_run_warc_pipeline(
        documents=[f"{DATA_PATH}/*.warc.gz"],
        destination=[self.tempdir],
        num_processes=1,
        ignore_existing=False,
        debug=True,
        source_name="test",
        skip_no_pre_taggers=False,
        skip_no_post_taggers=False,
        store_html_in_metadata=False,
        linearizer_name="no-op",
        pre_taggers=["cc_re"],
        post_taggers=["lingua_1e2"],
    )

    # read back every output shard as a list of parsed JSON documents
    collected: Dict[str, List[dict]] = {}
    for file_name in os.listdir(self.tempdir):
        full_path = os.path.join(self.tempdir, file_name)
        with smart_open.open(full_path, mode="rt", encoding="utf-8") as stream:
            collected[file_name] = [json.loads(line) for line in stream]
    return collected
13 changes: 8 additions & 5 deletions tests/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def skip_aws_tests() -> bool:
return (dolma_tests_skip or "false").lower() == "true"


def upload_test_documents(local_input: str, test_prefix: str) -> Tuple[str, str]:
remote_input = f"{test_prefix}/input/documents"
remote_output = f"{test_prefix}/output/documents"
def upload_test_documents(local_input: str, test_prefix: str, document_dir: str = "documents") -> Tuple[str, str]:
remote_input = f"{test_prefix}/input/{document_dir}"
remote_output = f"{test_prefix}/output/{document_dir}"

for i, local_fp in enumerate(glob_path(local_input)):
remote_fp = f"{remote_input}/{i:05d}.json.gz"
Expand Down Expand Up @@ -127,6 +127,7 @@ def upload_s3_prefix(s3_prefix: str, local_prefix: str):
bucket_name, prefix = parse_s3_path(s3_prefix)

# Upload every file matching the local prefix, keyed by its path relative to the prefix.
# BUG FIX: removed leftover debug statement `print(f"LOCAL_FP {local_fp}")`.
for local_fp in glob_path(local_prefix):
    name = local_fp.replace(local_prefix, "").lstrip("/")
    s3.upload_file(Bucket=bucket_name, Key=f"{prefix}/{name}", Filename=local_fp)

Expand Down Expand Up @@ -167,9 +168,11 @@ def writeUnits(

return [str(p) for p in file_paths]

def writeDocs(self, docs: List[str], partitions: int = 1, ext_dir: Optional[Path] = None) -> List[str]:
def writeDocs(
    self, docs: List[str], partitions: int = 1, ext_dir: Optional[Path] = None, unit_type: str = "documents"
) -> List[str]:
    """Write each text in `docs` as a document unit; `unit_type` names the destination folder."""
    encoded_docs = [
        {"id": str(doc_index), "text": doc_text, "source": __file__}
        for doc_index, doc_text in enumerate(docs)
    ]
    return self.writeUnits(units=encoded_docs, unit_type=unit_type, partitions=partitions, ext_dir=ext_dir)

def writeAttributes(
self,
Expand Down