from urllib.parse import urlparse
import time
import uuid
import os
import readline  # side-effect import: enables line editing for input() in CLI mode

import oci
from oci_core import get_os_client, get_df_client, os_upload_json, os_download_json, \
    os_has_object, os_delete_object

from spark_etl.job_submitters import AbstractJobSubmitter
from spark_etl import SparkETLLaunchFailure, SparkETLGetStatusFailure, SparkETLKillFailure
from .tools import check_response
from spark_etl.utils import CLIHandler, handle_server_ask
from spark_etl.core import ClientChannelInterface


class ClientChannel(ClientChannelInterface):
    def __init__(self, region, oci_config, run_base_dir, run_id):
        self.region = region
        self.oci_config = oci_config
        self.run_base_dir = run_base_dir
        self.run_id = run_id

        o = urlparse(run_base_dir)
        self.namespace = o.netloc.split('@')[1]
        self.bucket = o.netloc.split('@')[0]
        self.root_path = o.path[1:]  # remove the leading "/"
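        # e.g. run_base_dir = "oci://mybucket@mynamespace/etl/runs" (placeholder
        # names) parses into namespace "mynamespace", bucket "mybucket", and
        # root_path "etl/runs".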

    def _object_name(self, name):
        # objects for a run live under <root_path>/<run_id>/
        return os.path.join(self.root_path, self.run_id, name)

    def read_json(self, name):
        os_client = get_os_client(self.region, self.oci_config)
        return os_download_json(os_client, self.namespace, self.bucket, self._object_name(name))

    def write_json(self, name, payload):
        os_client = get_os_client(self.region, self.oci_config)
        os_upload_json(os_client, payload, self.namespace, self.bucket, self._object_name(name))

    def has_json(self, name):
        os_client = get_os_client(self.region, self.oci_config)
        return os_has_object(os_client, self.namespace, self.bucket, self._object_name(name))

    def delete_json(self, name):
        os_client = get_os_client(self.region, self.oci_config)
        os_delete_object(os_client, self.namespace, self.bucket, self._object_name(name))
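

# A minimal ClientChannel usage sketch (bucket/namespace names are placeholders):
#
#   channel = ClientChannel("IAD", oci_config, "oci://mybucket@mynamespace/etl/runs", run_id)
#   channel.write_json("args.json", {"date": "2021-01-01"})
#   if channel.has_json("result.json"):
#       result = channel.read_json("result.json")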


class DataflowJobSubmitter(AbstractJobSubmitter):
    def __init__(self, config):
        super(DataflowJobSubmitter, self).__init__(config)
        # config fields
        #     region      : e.g. "IAD"
        #     run_base_dir: an oci:// URI pointing to the run directory
        run_base_dir = self.config['run_base_dir']
        o = urlparse(run_base_dir)
        if o.scheme != 'oci':
            raise SparkETLLaunchFailure("run_base_dir must be in OCI")

    @property
    def region(self):
        return self.config['region']

    def run(self, deployment_location, options=None, args=None, handlers=None, on_job_submitted=None, cli_mode=False):
        # avoid shared mutable default arguments
        options = options or {}
        args = args or {}
        # options fields
        #     num_executors   : number
        #     driver_shape    : string
        #     executor_shape  : string
        #     lib_url_duration: number (represents the number of minutes)
        # on_job_submitted: callback, invoked as on_job_submitted(run_id, vendor_info={'oci_run_id': 'xxxyyy'})
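        # An illustrative options value (shape names are placeholders, not
        # recommendations):
        #   options = {
        #       "display_name": "my-etl-run",
        #       "num_executors": 2,
        #       "driver_shape": "VM.Standard2.1",
        #       "executor_shape": "VM.Standard2.1",
        #   }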

        o = urlparse(deployment_location)
        if o.scheme != 'oci':
            raise SparkETLLaunchFailure("deployment_location must be in OCI")

        run_base_dir = self.config['run_base_dir']
        run_id = str(uuid.uuid4())

        namespace = o.netloc.split('@')[1]
        bucket = o.netloc.split('@')[0]
        root_path = o.path[1:]  # remove the leading "/"

        # let's get the deployment.json
        os_client = get_os_client(self.region, self.config.get("oci_config"))
        deployment = os_download_json(os_client, namespace, bucket, os.path.join(root_path, "deployment.json"))
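        # deployment.json is expected to carry at least the Data Flow identifiers
        # consumed below, e.g. (placeholder OCIDs):
        #   {"compartment_id": "ocid1.compartment.oc1..xxx",
        #    "application_id": "ocid1.dataflowapplication.oc1..xxx"}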

        # let's upload the args
        client_channel = ClientChannel(
            self.region,
            self.config.get("oci_config"),
            run_base_dir,
            run_id
        )
        client_channel.write_json("args.json", args)

        df_client = get_df_client(self.region, self.config.get("oci_config"))
        crd_argv = {
            'compartment_id': deployment['compartment_id'],
            'application_id': deployment['application_id'],
            'display_name': options["display_name"],
            'arguments': [
                "--deployment-location", deployment_location,
                "--run-id", run_id,
                "--run-dir", os.path.join(run_base_dir, run_id),
                "--app-region", self.region,
            ],
        }
        # pass through optional sizing fields only when the caller provided them
        for key in ['num_executors', 'driver_shape', 'executor_shape']:
            if key in options:
                crd_argv[key] = options[key]

        create_run_details = oci.data_flow.models.CreateRunDetails(**crd_argv)
        r = df_client.create_run(create_run_details=create_run_details)
        check_response(r, lambda: SparkETLLaunchFailure("dataflow failed to run the application"))
        run = r.data
        oci_run_id = run.id
        print(f"Job launched, run_id = {run_id}, oci_run_id = {oci_run_id}")
        if on_job_submitted is not None:
            on_job_submitted(run_id, vendor_info={'oci_run_id': oci_run_id})
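
        # poll the Data Flow run every 10 seconds until it reaches a terminal
        # lifecycle state, servicing any server-side "ask" requests in between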
        cli_entered = False
        while True:
            time.sleep(10)
            r = df_client.get_run(run_id=oci_run_id)
            check_response(r, lambda: SparkETLGetStatusFailure("dataflow failed to get run status"))
            run = r.data
            print(f"Status: {run.lifecycle_state}")
            if run.lifecycle_state in ('FAILED', 'SUCCEEDED', 'CANCELED'):
                break
            handle_server_ask(client_channel, handlers)

            if cli_mode and not cli_entered and run.lifecycle_state == 'IN_PROGRESS':
                cli_entered = True
                cli_handler = CLIHandler(client_channel, None, handlers)
                cli_handler.loop()

        if run.lifecycle_state in ('FAILED', 'CANCELED'):
            raise Exception(f"Job failed with status: {run.lifecycle_state}")
        return client_channel.read_json('result.json')

    def kill(self, run_id):
        # run_id here should be the OCI Data Flow run id (the vendor_info
        # 'oci_run_id' reported by run()), not the submitter's own UUID
        df_client = get_df_client(self.region, self.config.get("oci_config"))

        r = df_client.delete_run(run_id)
        check_response(r, lambda: SparkETLKillFailure("dataflow failed to kill the run"))
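

# A minimal end-to-end sketch (all names and values below are placeholders):
#
#   submitter = DataflowJobSubmitter({
#       "region": "IAD",
#       "run_base_dir": "oci://mybucket@mynamespace/etl/runs",
#       "oci_config": None,
#   })
#   result = submitter.run(
#       "oci://mybucket@mynamespace/apps/myapp/1.0.0",
#       options={"display_name": "my-etl-run", "num_executors": 2},
#       args={"date": "2021-01-01"},
#   )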