Skip to content

Commit 9034ca5

Browse files
committed
Fix typing
1 parent c33865c commit 9034ca5

File tree

15 files changed

+354
-146
lines changed

15 files changed

+354
-146
lines changed

coverage_comment/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from coverage_comment import main
44

55

6-
def main_call(name):
6+
def main_call(name: str):
77
if name == "__main__":
88
main.main()
99

coverage_comment/activity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ActivityNotFound(Exception):
1616
def find_activity(
1717
event_name: str,
1818
is_default_branch: bool,
19-
event_type: str,
19+
event_type: str | None,
2020
is_pr_merged: bool,
2121
) -> str:
2222
"""Find the activity to perform based on the event type and payload."""

coverage_comment/coverage.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import pathlib
88
from collections.abc import Sequence
9+
from typing import Any
910

1011
from coverage_comment import log, subprocess
1112

@@ -21,20 +22,19 @@ class CoverageMetadata:
2122
show_contexts: bool
2223

2324

24-
class OutputMixin:
25-
def as_output(self, prefix: str) -> dict:
26-
data = dataclasses.asdict(self)
27-
output = {}
28-
for key, value in data.items():
29-
if value is not None and not isinstance(value, dict):
30-
output[f"{prefix}_{key}"] = (
31-
float(value) if isinstance(value, decimal.Decimal) else value
32-
)
33-
return output
25+
def as_output(obj: Any, prefix: str) -> dict[str, Any]:
26+
data = dataclasses.asdict(obj)
27+
output: dict[str, Any] = {}
28+
for key, value in data.items():
29+
if value is not None and not isinstance(value, dict):
30+
output[f"{prefix}_{key}"] = (
31+
float(value) if isinstance(value, decimal.Decimal) else value
32+
)
33+
return output
3434

3535

3636
@dataclasses.dataclass(kw_only=True)
37-
class CoverageInfo(OutputMixin):
37+
class CoverageInfo:
3838
covered_lines: int
3939
num_statements: int
4040
percent_covered: decimal.Decimal
@@ -88,7 +88,7 @@ def violation_lines(self) -> list[int]:
8888

8989

9090
@dataclasses.dataclass(kw_only=True)
91-
class DiffCoverage(OutputMixin):
91+
class DiffCoverage:
9292
total_num_lines: int
9393
total_num_violations: int
9494
total_percent_covered: decimal.Decimal
@@ -112,7 +112,7 @@ def compute_coverage(
112112

113113
def get_coverage_info(
114114
merge: bool, coverage_path: pathlib.Path
115-
) -> tuple[dict, Coverage]:
115+
) -> tuple[dict[str, Any], Coverage]:
116116
try:
117117
if merge:
118118
subprocess.run("coverage", "combine", path=coverage_path)
@@ -160,7 +160,7 @@ def generate_coverage_markdown(coverage_path: pathlib.Path) -> str:
160160
)
161161

162162

163-
def _make_coverage_info(data: dict) -> CoverageInfo:
163+
def _make_coverage_info(data: dict[str, Any]) -> CoverageInfo:
164164
"""Build a CoverageInfo object from a "summary" or "totals" key."""
165165
return CoverageInfo(
166166
covered_lines=data["covered_lines"],
@@ -180,7 +180,7 @@ def _make_coverage_info(data: dict) -> CoverageInfo:
180180
)
181181

182182

183-
def extract_info(data: dict, coverage_path: pathlib.Path) -> Coverage:
183+
def extract_info(data: dict[str, Any], coverage_path: pathlib.Path) -> Coverage:
184184
"""
185185
{
186186
"meta": {
@@ -246,7 +246,7 @@ def extract_info(data: dict, coverage_path: pathlib.Path) -> Coverage:
246246
def get_diff_coverage_info(
247247
added_lines: dict[pathlib.Path, list[int]], coverage: Coverage
248248
) -> DiffCoverage:
249-
files = {}
249+
files: dict[pathlib.Path, FileDiffCoverage] = {}
250250
total_num_lines = 0
251251
total_num_violations = 0
252252
num_changed_lines = 0

coverage_comment/files.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
import pathlib
1212
import shutil
1313
import tempfile
14-
from collections.abc import Callable
15-
from typing import Protocol, TypedDict
14+
from typing import Any, Protocol, TypedDict
1615

1716
import httpx
1817

@@ -60,7 +59,7 @@ def apply(self):
6059

6160
def compute_files(
6261
line_rate: decimal.Decimal,
63-
raw_coverage_data: dict,
62+
raw_coverage_data: dict[str, Any],
6463
coverage_path: pathlib.Path,
6564
minimum_green: decimal.Decimal,
6665
minimum_orange: decimal.Decimal,
@@ -97,7 +96,9 @@ def compute_files(
9796

9897

9998
def compute_datafile(
100-
raw_coverage_data: dict, line_rate: decimal.Decimal, coverage_path: pathlib.Path
99+
raw_coverage_data: dict[str, Any],
100+
line_rate: decimal.Decimal,
101+
coverage_path: pathlib.Path,
101102
) -> str:
102103
return json.dumps(
103104
{
@@ -108,7 +109,7 @@ def compute_datafile(
108109
)
109110

110111

111-
def parse_datafile(contents) -> tuple[coverage.Coverage | None, decimal.Decimal]:
112+
def parse_datafile(contents: str) -> tuple[coverage.Coverage | None, decimal.Decimal]:
112113
file_contents = json.loads(contents)
113114
coverage_rate = decimal.Decimal(str(file_contents["coverage"])) / decimal.Decimal(
114115
"100"
@@ -128,7 +129,11 @@ class ImageURLs(TypedDict):
128129
dynamic: str
129130

130131

131-
def get_urls(url_getter: Callable) -> ImageURLs:
132+
class URLGetter(Protocol):
133+
def __call__(self, path: pathlib.Path) -> str: ...
134+
135+
136+
def get_urls(url_getter: URLGetter) -> ImageURLs:
132137
return {
133138
"direct": url_getter(path=BADGE_PATH),
134139
"endpoint": badge.get_endpoint_url(endpoint_url=url_getter(path=ENDPOINT_PATH)),
@@ -137,7 +142,9 @@ def get_urls(url_getter: Callable) -> ImageURLs:
137142

138143

139144
def get_coverage_html_files(
140-
*, coverage_path: pathlib.Path, gen_dir: pathlib.Path = pathlib.Path("/tmp")
145+
*,
146+
coverage_path: pathlib.Path,
147+
gen_dir: pathlib.Path | None = None,
141148
) -> ReplaceDir:
142149
html_dir = pathlib.Path(tempfile.mkdtemp(dir=gen_dir))
143150
coverage.generate_coverage_html_files(

coverage_comment/github.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import re
88
import sys
99
import zipfile
10+
from collections.abc import Iterable
1011
from typing import Any
1112
from urllib.parse import urlparse
1213

@@ -106,7 +107,9 @@ def download_artifact(
106107
raise NoArtifact(f"File named {filename} not found in artifact {artifact_name}")
107108

108109

109-
def _fetch_artifacts(repo_path, run_id):
110+
def _fetch_artifacts(
111+
repo_path: github_client.Endpoint, run_id: int
112+
) -> Iterable[github_client.JsonObject]:
110113
page = 1
111114
total_fetched = 0
112115

@@ -143,31 +146,13 @@ def find_pr_for_branch(
143146

144147
full_branch = f"{owner}:{branch}"
145148

146-
common_kwargs = {"head": full_branch, "sort": "updated", "direction": "desc"}
147-
try:
148-
return next(
149-
iter(
150-
pr.number
151-
for pr in github.repos(repository).pulls.get(
152-
state="open", **common_kwargs
153-
)
154-
)
155-
)
156-
except StopIteration:
157-
pass
158-
log.info(f"No open PR found for branch {branch}, defaulting to all PRs")
149+
for state in ["open", "all"]:
150+
for pr in github.repos(repository).pulls.get(
151+
state=state, head=full_branch, sort="updated", direction="desc"
152+
):
153+
return pr.number # pyright: ignore
159154

160-
try:
161-
return next(
162-
iter(
163-
pr.number
164-
for pr in github.repos(repository).pulls.get(
165-
state="all", **common_kwargs
166-
)
167-
)
168-
)
169-
except StopIteration:
170-
raise CannotDeterminePR(f"No open PR found for branch {branch}")
155+
raise CannotDeterminePR(f"No open PR found for branch {branch}")
171156

172157

173158
def get_my_login(github: github_client.GitHub) -> str:
@@ -195,10 +180,14 @@ def post_comment(
195180
comments_path = github.repos(repository).issues.comments
196181

197182
for comment in issue_comments_path.get():
198-
if comment.user.login == me and marker in comment.body:
183+
login: str = comment.user.login # pyright: ignore
184+
body: str = comment.body # pyright: ignore
185+
comment_id: int = comment.id # pyright: ignore
186+
187+
if login == me and marker in body:
199188
log.info("Update previous comment")
200189
try:
201-
comments_path(comment.id).patch(body=contents)
190+
comments_path(comment_id).patch(body=contents)
202191
except github_client.Forbidden as exc:
203192
raise CannotPostComment from exc
204193
break
@@ -296,7 +285,7 @@ def get_pr_diff(github: github_client.GitHub, repository: str, pr_number: int) -
296285
return (
297286
github.repos(repository)
298287
.pulls(pr_number)
299-
.get(headers={"Accept": "application/vnd.github.v3.diff"})
288+
.get(headers={"Accept": "application/vnd.github.v3.diff"}, text=True)
300289
)
301290

302291

@@ -309,5 +298,5 @@ def get_branch_diff(
309298
return (
310299
github.repos(repository)
311300
.compare(f"{base_branch}...{head_branch}")
312-
.get(headers={"Accept": "application/vnd.github.v3.diff"})
301+
.get(headers={"Accept": "application/vnd.github.v3.diff"}, text=True)
313302
)

0 commit comments

Comments
 (0)