
import json
from urllib.parse import urlparse
import time
import uuid
from datetime import datetime, timedelta
import os
from termcolor import colored, cprint
import readline

import oci
from oci_core import get_os_client, get_df_client, os_upload, os_upload_json, os_download, os_download_json, \
    os_has_object, os_delete_object

from spark_etl.job_submitters import AbstractJobSubmitter
from spark_etl import SparkETLLaunchFailure, SparkETLGetStatusFailure, SparkETLKillFailure
from .tools import check_response, remote_execute
from spark_etl.utils import CLIHandler, handle_server_ask
from spark_etl.core import ClientChannelInterface

class ClientChannel(ClientChannelInterface):
    """Client-side channel that exchanges JSON files with a run via OCI Object Storage."""

    def __init__(self, region, oci_config, run_base_dir, run_id):
        self.region = region
        self.oci_config = oci_config
        self.run_base_dir = run_base_dir
        self.run_id = run_id

        # run_base_dir is an oci:// URI; its netloc is "<bucket>@<namespace>"
        o = urlparse(run_base_dir)
        self.namespace = o.netloc.split('@')[1]
        self.bucket = o.netloc.split('@')[0]
        self.root_path = o.path[1:]  # remove the leading "/"

    def read_json(self, name):
        os_client = get_os_client(self.region, self.oci_config)
        object_name = os.path.join(self.root_path, self.run_id, name)
        return os_download_json(os_client, self.namespace, self.bucket, object_name)

    def write_json(self, name, payload):
        os_client = get_os_client(self.region, self.oci_config)
        object_name = os.path.join(self.root_path, self.run_id, name)
        os_upload_json(os_client, payload, self.namespace, self.bucket, object_name)

    def has_json(self, name):
        os_client = get_os_client(self.region, self.oci_config)
        object_name = os.path.join(self.root_path, self.run_id, name)
        return os_has_object(os_client, self.namespace, self.bucket, object_name)

    def delete_json(self, name):
        os_client = get_os_client(self.region, self.oci_config)
        object_name = os.path.join(self.root_path, self.run_id, name)
        os_delete_object(os_client, self.namespace, self.bucket, object_name)

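# Example of the object layout (values are illustrative assumptions, not from the
# original source): with
#   run_base_dir = "oci://spark-etl-runs@mytenancy/runs"
# urlparse gives netloc "spark-etl-runs@mytenancy" and path "/runs", so the channel
# resolves bucket="spark-etl-runs", namespace="mytenancy", root_path="runs"; for
# run_id "1234", write_json("args.json", ...) uploads the object "runs/1234/args.json"
# and read_json("result.json") downloads "runs/1234/result.json" from that bucket.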

class DataflowJobSubmitter(AbstractJobSubmitter):
    def __init__(self, config):
        super(DataflowJobSubmitter, self).__init__(config)
        # config fields:
        #   region       : e.g. "IAD"
        #   run_base_dir : an oci:// URI pointing to the run directory
        run_base_dir = self.config['run_base_dir']
        o = urlparse(run_base_dir)
        if o.scheme != 'oci':
            raise SparkETLLaunchFailure("run_base_dir must be in OCI")

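    # A minimal config sketch (field values are illustrative assumptions, not from
    # the original source); an "oci_config" entry is optional and, when present, is
    # passed through to the OCI Object Storage and Data Flow clients:
    #   config = {
    #       "region": "IAD",
    #       "run_base_dir": "oci://spark-etl-runs@mytenancy/runs",
    #   }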

    @property
    def region(self):
        return self.config['region']


    def run(self, deployment_location, options={}, args={}, handlers=None, on_job_submitted=None, cli_mode=False):
        # options fields:
        #   num_executors    : number
        #   driver_shape     : string
        #   executor_shape   : string
        #   display_name     : string
        #   lib_url_duration : number (represents the number of minutes)
        # on_job_submitted: callback, on_job_submitted(run_id, vendor_info={'oci_run_id': 'xxxyyy'})
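        # For example (shape names and display name are illustrative assumptions,
        # not from the original source):
        #   options = {
        #       "display_name": "my-etl-job",
        #       "driver_shape": "VM.Standard2.4",
        #       "executor_shape": "VM.Standard2.4",
        #       "num_executors": 2,
        #   }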

        o = urlparse(deployment_location)
        if o.scheme != 'oci':
            raise SparkETLLaunchFailure("deployment_location must be in OCI")

        run_base_dir = self.config['run_base_dir']
        run_id = str(uuid.uuid4())

        # deployment_location's netloc is "<bucket>@<namespace>"
        namespace = o.netloc.split('@')[1]
        bucket = o.netloc.split('@')[0]
        root_path = o.path[1:]  # remove the leading "/"

        # let's get the deployment.json
        os_client = get_os_client(self.region, self.config.get("oci_config"))
        deployment = os_download_json(os_client, namespace, bucket, os.path.join(root_path, "deployment.json"))

        # let's upload the args
        client_channel = ClientChannel(
            self.region,
            self.config.get("oci_config"),
            run_base_dir,
            run_id
        )
        client_channel.write_json("args.json", args)

        # upload args.json into the run directory as well (this duplicates the
        # client_channel.write_json call above, targeting the same object)
        o = urlparse(self.config['run_base_dir'])
        namespace = o.netloc.split('@')[1]
        bucket = o.netloc.split('@')[0]
        root_path = o.path[1:]  # remove the leading "/"
        os_upload_json(os_client, args, namespace, bucket, f"{root_path}/{run_id}/args.json")


        df_client = get_df_client(self.region, self.config.get("oci_config"))
        crd_argv = {
            'compartment_id': deployment['compartment_id'],
            'application_id': deployment['application_id'],
            'display_name': options["display_name"],
            'arguments': [
                "--deployment-location", deployment_location,
                "--run-id", run_id,
                "--run-dir", os.path.join(run_base_dir, run_id),
                "--app-region", self.region,
            ],
        }
        # optional run sizing, taken from options when provided
        for key in ['num_executors', 'driver_shape', 'executor_shape']:
            if key in options:
                crd_argv[key] = options[key]

        create_run_details = oci.data_flow.models.CreateRunDetails(**crd_argv)
        r = df_client.create_run(create_run_details=create_run_details)
        check_response(r, lambda: SparkETLLaunchFailure("dataflow failed to run the application"))
        run = r.data
        oci_run_id = run.id
        print(f"Job launched, run_id = {run_id}, oci_run_id = {run.id}")
        if on_job_submitted is not None:
            on_job_submitted(run_id, vendor_info={'oci_run_id': run.id})

        # poll the Data Flow run until it reaches a terminal state, relaying any
        # server-side asks to the registered handlers in between
        cli_entered = False
        while True:
            time.sleep(10)
            r = df_client.get_run(run_id=run.id)
            check_response(r, lambda: SparkETLGetStatusFailure("dataflow failed to get run status"))
            run = r.data
            print(f"Status: {run.lifecycle_state}")
            if run.lifecycle_state in ('FAILED', 'SUCCEEDED', 'CANCELED'):
                break
            handle_server_ask(client_channel, handlers)

            if cli_mode and not cli_entered and run.lifecycle_state == 'IN_PROGRESS':
                cli_entered = True
                cli_handler = CLIHandler(client_channel, None, handlers)
                cli_handler.loop()

        if run.lifecycle_state in ('FAILED', 'CANCELED'):
            raise Exception(f"Job failed with status: {run.lifecycle_state}")
        return client_channel.read_json('result.json')
        # return {
        #     'state': run.lifecycle_state,
        #     'run_id': run_id,
        #     'succeeded': run.lifecycle_state == 'SUCCEEDED'
        # }

    def kill(self, run_id):
        df_client = get_df_client(self.region)

        r = df_client.delete_run(run_id)
        check_response(r, lambda: SparkETLKillFailure("dataflow failed to kill the run"))
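# A minimal usage sketch (region is the example from the config comment above; bucket
# names, deployment location, display name, and args are illustrative assumptions, not
# from the original source): the submitter is built from the config, and run() blocks
# until the Data Flow run reaches a terminal state, then returns whatever payload the
# job wrote to result.json through the channel.
#
#   submitter = DataflowJobSubmitter({
#       "region": "IAD",
#       "run_base_dir": "oci://spark-etl-runs@mytenancy/runs",
#   })
#   result = submitter.run(
#       "oci://spark-etl-apps@mytenancy/apps/my_app/1.0.0.0",
#       options={"display_name": "my-etl-job", "num_executors": 2},
#       args={"input_date": "2021-01-01"},
#   )
#   print(result)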