xref: /aosp_15_r20/tools/asuite/atest/logstorage/atest_gcp_utils.py (revision c2e18aaa1096c836b086f94603d04f4eb9cf37f5)
1# Copyright (C) 2020 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utility functions for atest."""
15from __future__ import print_function
16
17import getpass
18import logging
19import os
20import pathlib
21from pathlib import Path
22from socket import socket
23import subprocess
24import time
25from typing import Any, Callable
26import uuid
27
28from atest import atest_utils
29from atest import constants
30from atest.atest_enum import DetectType
31from atest.metrics import metrics
32import httplib2
33from oauth2client import client as oauth2_client
34from oauth2client import contrib as oauth2_contrib
35from oauth2client import tools as oauth2_tools
36
37
class RunFlowFlags:
  """Flag container handed to oauth2client.tools.run_flow.

  Attributes are plain values set once at construction time.
  """

  def __init__(self, browser_auth):
    """Populate the flag values.

    Args:
        browser_auth: Truthy to authorize through a local web server/browser;
          the stored `noauth_local_webserver` flag is its logical negation.
    """
    # The only value derived from the argument.
    self.noauth_local_webserver = not browser_auth
    # Fixed settings.
    self.logging_level = 'ERROR'
    self.auth_host_name = 'localhost'
    self.auth_host_port = [8080, 8090]
46
47
class GCPHelper:
  """Helper for obtaining OAuth2 credentials for GCP APIs.

  `get_credential_with_auth_flow` is the main entry point. It tries, in order:
  an SSO token exchange, a credential cached in a file, and finally an
  interactive OAuth2 flow.
  """

  def __init__(
      self,
      client_id=None,
      client_secret=None,
      user_agent=None,
      scope=constants.SCOPE_BUILD_API_SCOPE,
  ):
    """Init stuff for GCPHelper class.

    Args:
        client_id: String, client id from the cloud project.
        client_secret: String, client secret for the client_id.
        user_agent: The user agent for the credential.
        scope: String, scopes separated by space.
    """
    self.client_id = client_id
    self.client_secret = client_secret
    self.user_agent = user_agent
    self.scope = scope

  def _get_credential_storage(self, creds_file_path):
    """Build the credential storage object backed by the given file.

    Args:
        creds_file_path: Credential file path; converted to an absolute path.

    Returns:
        The storage object returned by oauth2client's
        multiprocess_file_storage.get_credential_storage, keyed by this
        helper's client id, user agent and scope.
    """
    return oauth2_contrib.multiprocess_file_storage.get_credential_storage(
        filename=os.path.abspath(creds_file_path),
        client_id=self.client_id,
        user_agent=self.user_agent,
        scope=self.scope,
    )

  def get_refreshed_credential_from_file(self, creds_file_path):
    """Get refreshed credential from file.

    Args:
        creds_file_path: Credential file path.

    Returns:
        An oauth2client.OAuth2Credentials instance, or None when the file
        holds no credential or the credential is still invalid after the
        refresh attempt.
    """
    credential = self.get_credential_from_file(creds_file_path)
    if credential:
      try:
        credential.refresh(httplib2.Http())
      except oauth2_client.AccessTokenRefreshError as e:
        # A failed refresh is not fatal here; validity is checked below.
        logging.debug('Token refresh error: %s', e)
      if not credential.invalid:
        return credential
    logging.debug('Cannot get credential.')
    return None

  def get_credential_from_file(self, creds_file_path):
    """Get credential from file.

    Args:
        creds_file_path: Credential file path.

    Returns:
        Whatever the credential storage's get() returns for this helper's
        client id, user agent and scope (an oauth2client.OAuth2Credentials
        instance when one is stored).
    """
    return self._get_credential_storage(creds_file_path).get()

  def get_credential_with_auth_flow(self, creds_file_path):
    """Get a credential, running the OAuth flow if nothing usable exists.

    The lookup order is:
      1. SSO token exchange (see _get_sso_access_token).
      2. Refreshed credential read from creds_file_path.
      3. Interactive OAuth2 flow, storing the result in creds_file_path.

    Args:
        creds_file_path: Credential file path.

    Returns:
        An oauth2client credentials instance.
    """
    # SSO auth
    try:
      token = self._get_sso_access_token()
      # Only trust the SSO path when a token was actually produced. The
      # credentials object itself is always truthy, so testing it (as was done
      # previously) let a None token short-circuit the fallbacks below.
      if token:
        return oauth2_client.AccessTokenCredentials(token, 'atest')
    # pylint: disable=broad-except
    except Exception as e:
      # Best effort: any SSO failure falls through to the GCP auth flow.
      logging.debug('Exception:%s', e)
    # GCP auth flow
    credentials = self.get_refreshed_credential_from_file(creds_file_path)
    if not credentials:
      return self._run_auth_flow(self._get_credential_storage(creds_file_path))
    return credentials

  def _run_auth_flow(self, storage):
    """Get user oauth2 credentials.

    Using the loopback IP address flow for desktop clients.

    Args:
        storage: Credential storage object that run_flow receives.

    Returns:
        An oauth2client.OAuth2Credentials instance.
    """
    flags = RunFlowFlags(browser_auth=True)

    # Get a free port on demand: bind to port 0 and read back the assigned
    # port, retrying until it is at least 10000. The probing socket is closed
    # again when the `with` block exits.
    port = None
    while not port or port < 10000:
      with socket() as local_socket:
        local_socket.bind(('', 0))
        _, port = local_socket.getsockname()
    # NOTE(review): flags.auth_host_port lists 8080/8090 while the redirect URI
    # uses the port picked above -- confirm which one run_flow listens on.
    flow = oauth2_client.OAuth2WebServerFlow(
        client_id=self.client_id,
        client_secret=self.client_secret,
        scope=self.scope,
        user_agent=self.user_agent,
        redirect_uri=f'http://localhost:{port}',
    )
    return oauth2_tools.run_flow(flow=flow, storage=storage, flags=flags)

  @staticmethod
  def _get_sso_access_token():
    """Exchange corp SSO for a scoped oauth token via an external command.

    Runs constants.TOKEN_EXCHANGE_COMMAND through the shell, feeding it
    constants.TOKEN_EXCHANGE_REQUEST (formatted with the current user and
    constants.SCOPE) on stdin.

    Returns:
        A token string, or None when no token exchange command is configured.

    Raises:
        subprocess.CalledProcessError: The command exited with a non-zero
          status.
        IndexError: The command output contains no double-quoted section.
    """
    if not constants.TOKEN_EXCHANGE_COMMAND:
      return None

    request = constants.TOKEN_EXCHANGE_REQUEST.format(
        user=getpass.getuser(), scope=constants.SCOPE
    )
    # The output format is: oauth2_token: "<TOKEN>"
    # so the token is the text between the first pair of double quotes.
    return subprocess.run(
        constants.TOKEN_EXCHANGE_COMMAND,
        input=request,
        check=True,
        text=True,
        shell=True,
        stdout=subprocess.PIPE,
    ).stdout.split('"')[1]
197
198
# TODO: The usage of build_client should be removed from this method because
# it's not related to this module. For now, build_client_creator is annotated
# as a bare Callable (leaving its return type unspecified) to avoid a circular
# import.
def do_upload_flow(
    extra_args: dict[str, str],
    build_client_creator: Callable,
    invocation_properties: dict[str, str] = None,
) -> tuple:
  """Run upload flow.

  Fetches a credential and, when one is available, creates the build records
  through the build client, stores the resulting ids in `extra_args` and
  writes the access token to constants.TOKEN_FILE_PATH. Timing of the
  credential fetch and of the preparation step is reported through metrics.

  Args:
      extra_args: Dict of extra args to add to test run. Updated in place with
        the invocation id, work unit id, local build id and build target.
      build_client_creator: A function that takes a credential and returns a
        BuildClient object.
      invocation_properties: Additional invocation properties to write into the
        invocation. None is treated as an empty dict.

  Return:
      A tuple of credential object and invocation information dict, or
      (None, None) when no credential could be fetched.
  """
  invocation_properties = invocation_properties or {}
  fetch_cred_start = time.time()
  creds = fetch_credential()
  metrics.LocalDetectEvent(
      detect_type=DetectType.FETCH_CRED_MS,
      result=int((time.time() - fetch_cred_start) * 1000),
  )
  if not creds:
    return None, None

  prepare_upload_start = time.time()
  build_client = build_client_creator(creds)
  inv, workunit, local_build_id, build_target = _prepare_data(
      build_client, invocation_properties
  )
  metrics.LocalDetectEvent(
      detect_type=DetectType.UPLOAD_PREPARE_MS,
      result=int((time.time() - prepare_upload_start) * 1000),
  )
  extra_args[constants.INVOCATION_ID] = inv['invocationId']
  extra_args[constants.WORKUNIT_ID] = workunit['id']
  extra_args[constants.LOCAL_BUILD_ID] = local_build_id
  extra_args[constants.BUILD_TARGET] = build_target

  # exist_ok avoids the check-then-create race of exists() + makedirs(); the
  # emptiness guard covers a token path that has no directory component, for
  # which makedirs('') would raise.
  token_dir = os.path.dirname(constants.TOKEN_FILE_PATH)
  if token_dir:
    os.makedirs(token_dir, exist_ok=True)
  with open(constants.TOKEN_FILE_PATH, 'w') as token_file:
    # Prefer the token from the token response when the credential has one;
    # otherwise use the credential's own access_token attribute.
    if creds.token_response:
      token_file.write(creds.token_response['access_token'])
    else:
      token_file.write(creds.access_token)
  return creds, inv
251
252
def fetch_credential():
  """Return a credential for the atest client.

  The credential file lives in the atest config folder under
  constants.CREDENTIAL_FILE_NAME; the lookup itself is delegated to
  GCPHelper.get_credential_with_auth_flow.
  """
  config_dir = atest_utils.get_config_folder()
  credential_file = config_dir.joinpath(constants.CREDENTIAL_FILE_NAME)
  helper = GCPHelper(
      client_id=constants.CLIENT_ID,
      client_secret=constants.CLIENT_SECRET,
      user_agent='atest',
  )
  return helper.get_credential_with_auth_flow(credential_file)
263
264
def _prepare_data(client, invocation_properties: dict[str, str]):
  """Create the build api records needed for an upload.

  Inserts, in order, a local build, its build attempts, an invocation and a
  work unit through `client`. Logging at INFO and below is disabled for the
  duration of the calls.

  Args:
      client: The build client object used for every insert_* call.
      invocation_properties: Additional invocation properties to write into the
        invocation.

  Return:
      A 4-tuple: the invocation object, the work unit object, the build id of
      the inserted local build, and the build target.
  """
  try:
    logging.disable(logging.INFO)
    local_build_uuid = str(uuid.uuid4())
    branch_name = _get_branch(client)
    build_target = _get_target(branch_name, client)
    local_build = client.insert_local_build(
        local_build_uuid, build_target, branch_name
    )
    client.insert_build_attempts(local_build)
    inv = client.insert_invocation(local_build, invocation_properties)
    work_unit = client.insert_work_unit(inv)
    build_id = local_build['buildId']
    return inv, work_unit, build_id, build_target
  finally:
    # Runs on success and on error; resets the disable threshold to NOTSET.
    logging.disable(logging.NOTSET)
289
290
def _get_branch(build_client):
  """Choose the branch name to use for the local build.

  Args:
      build_client: The build client object; its get_branch() is queried with
        the local manifest branch prefixed by "git_".

  Return:
      The "git_"-prefixed local manifest branch when the build client returns
      a truthy result for it. Otherwise the fallback: "git_main" when
      constants.CREDENTIAL_FILE_NAME is truthy, "aosp-main" when it is not.
  """
  if constants.CREDENTIAL_FILE_NAME:
    fallback_branch = 'git_main'
  else:
    fallback_branch = 'aosp-main'
  manifest_branch = f'git_{atest_utils.get_manifest_branch()}'
  if build_client.get_branch(manifest_branch):
    return manifest_branch
  return fallback_branch
304
305
def _get_target(branch, build_client):
  """Choose the build target to use for the local build.

  Args:
      branch: The branch whose target list is fetched.
      build_client: The build client object; list_target(branch) must return a
        mapping with a 'targets' list whose entries have a 'target' key.

  Return:
      The local build target when it appears among the branch's targets,
      "aosp_x86_64-trunk_staging-userdebug" otherwise.
  """
  fallback_target = 'aosp_x86_64-trunk_staging-userdebug'
  selected_target = atest_utils.get_build_target()
  listing = build_client.list_target(branch)
  known_targets = tuple(entry['target'] for entry in listing['targets'])
  if selected_target in known_targets:
    return selected_target
  return fallback_target
321