Added __eq__ method to Checkpoint class

Signed-off-by: sdimovv <36302090+sdimovv@users.noreply.github.com>
This commit is contained in:
sdimovv 2023-02-07 15:21:56 +00:00 committed by GitHub
parent 91d4484684
commit c7bf137e5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 2 deletions

View File

@ -33,7 +33,11 @@ class Checkpoint(Model):
checkpoint=self.id,
checkpoint_dir=self.collection.checkpoint_dir,
)
def __eq__(self, other):
if isinstance(other, Checkpoint):
return self.id == other.id
return self.id == other
class CheckpointCollection(Collection):
"""(Experimental)."""
@ -94,7 +98,7 @@ class CheckpointCollection(Collection):
checkpoints = self.list()
for checkpoint in checkpoints:
if checkpoint.id == id:
if checkpoint == id:
return checkpoint
raise CheckpointNotFound(