mrosinski commited on
Commit
205fe98
·
1 Parent(s): a1842a2

application files

Browse files
Files changed (2) hide show
  1. app.py +55 -0
  2. test-model.ipynb +289 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+
5
+ article = '''<img src="https://corporateweb-v3-corporatewebv3damstrawebassetbuck-1lruglqypgb84.s3-ap-southeast-2.amazonaws.com/public/cta-2.jpg"/> '''
6
+
7
+ examples = [
8
+ [
9
+ '''
10
+ A truck narrowly missed a person on a bicycle when they were reversing out of the depot on Friday. \
11
+ It was early morning before the sun was up and the cyclist did not have a light. Fortunately the \
12
+ driver spotted the rider and braked heavily to avoid a collision.
13
+ '''],
14
+ [
15
+ '''
16
+ When making a coffee I noticed the cord to the coffee machine was frayed and tagged it out of service. Now I need to find a barista!'''],
17
+ [
18
+ '''
19
+ A worker was using a grinder in a confined space when he became dizzy from the fumes in the area and had to be helped out. \
20
+ The gas monitor he was using was found to be faulty and when the area was assessed with another monitor there was an \
21
+ unacceptably high level of CO2 in the area''']]
22
+
23
+
24
+ title = "Incident Prioritisation Tool"
25
+ description = "Triage new incidents based on a distilbert-uncased NLP model that has been fine tuned on descriptions of incidents \
26
+ that have been risk rated in the past"
27
+
28
+ pipe = pipeline("text-classification", model="mrosinski/autotrain-distilbert-risk-ranker-1593356256")
29
+
30
+ def predict(text):
31
+ # if len(text[0]) > 60:
32
+ preds = pipe(text)[0]
33
+ return preds["label"].title(), f'Confidence Score: {round(preds["score"]*100, 1)}%'
34
+ # else:
35
+ # return 'Invalid entry', 'Try adding more information to describe the incident'
36
+
37
+
38
+ gradio_ui = gr.Interface(
39
+ fn=predict,
40
+ title=title,
41
+ description=description,
42
+ inputs=[
43
+ gr.inputs.Textbox(lines=5, label="Paste some text here"),
44
+ ],
45
+ outputs=[
46
+ gr.outputs.Textbox(label="Label"),
47
+ gr.outputs.Textbox(label="Score"),
48
+ ],
49
+ examples=examples,
50
+ article=article
51
+
52
+ )
53
+
54
+ gradio_ui.launch(debug=True)
55
+
test-model.ipynb ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/Users/mattrosinski/git/transformers-demos\n"
13
+ ]
14
+ }
15
+ ],
16
+ "source": [
17
+ "!pwd"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 2,
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "/Users/mattrosinski/mambaforge/bin/python\n"
30
+ ]
31
+ }
32
+ ],
33
+ "source": [
34
+ "!which python"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 3,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "import gradio as gr\n",
44
+ "from transformers import pipeline"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 4,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "pipe = pipeline(\"text-classification\", model=\"mrosinski/autotrain-distilbert-risk-ranker-1593356256\")"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 5,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "text = ['''\n",
63
+ "A truck narrowly missed a person on a bicycle when they were reversing out of the depot on Friday. \\\n",
64
+ " It was early morning before the sun was up and the cyclist did not have a light. Fortunately the \\\n",
65
+ " driver spotted the rider and braked heavily to avoid a collision.\n",
66
+ "''']"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 40,
72
+ "metadata": {},
73
+ "outputs": [
74
+ {
75
+ "data": {
76
+ "text/plain": [
77
+ "60"
78
+ ]
79
+ },
80
+ "execution_count": 40,
81
+ "metadata": {},
82
+ "output_type": "execute_result"
83
+ }
84
+ ],
85
+ "source": [
86
+ "len('A truck narrowly missed a person on a bicycle when they were')"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 39,
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "data": {
96
+ "text/plain": [
97
+ "277"
98
+ ]
99
+ },
100
+ "execution_count": 39,
101
+ "metadata": {},
102
+ "output_type": "execute_result"
103
+ }
104
+ ],
105
+ "source": [
106
+ "len(text[0])"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 11,
112
+ "metadata": {},
113
+ "outputs": [
114
+ {
115
+ "data": {
116
+ "text/plain": [
117
+ "{'label': 'high risk', 'score': 0.7180770635604858}"
118
+ ]
119
+ },
120
+ "execution_count": 11,
121
+ "metadata": {},
122
+ "output_type": "execute_result"
123
+ }
124
+ ],
125
+ "source": [
126
+ "preds = pipe(text)[0]\n",
127
+ "preds"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 15,
133
+ "metadata": {},
134
+ "outputs": [
135
+ {
136
+ "data": {
137
+ "text/plain": [
138
+ "0.7180770635604858"
139
+ ]
140
+ },
141
+ "execution_count": 15,
142
+ "metadata": {},
143
+ "output_type": "execute_result"
144
+ }
145
+ ],
146
+ "source": [
147
+ "score = preds['score']\n",
148
+ "score"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 17,
154
+ "metadata": {},
155
+ "outputs": [
156
+ {
157
+ "data": {
158
+ "text/plain": [
159
+ "float"
160
+ ]
161
+ },
162
+ "execution_count": 17,
163
+ "metadata": {},
164
+ "output_type": "execute_result"
165
+ }
166
+ ],
167
+ "source": [
168
+ "type(score)"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 29,
174
+ "metadata": {},
175
+ "outputs": [
176
+ {
177
+ "data": {
178
+ "text/plain": [
179
+ "('High Risk', 'Confidence score: 71.8%')"
180
+ ]
181
+ },
182
+ "execution_count": 29,
183
+ "metadata": {},
184
+ "output_type": "execute_result"
185
+ }
186
+ ],
187
+ "source": [
188
+ "preds[\"label\"].title(), f'Confidence Score: {round(preds[\"score\"], 3)*100}%'"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 41,
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "\n",
198
+ "def predict(text):\n",
199
+ " if len(text) < 60:\n",
200
+ " return 'Invalid entry', 'Try adding more information to describe the incident'\n",
201
+ " preds = pipe(text)[0]\n",
202
+ " return preds[\"label\"].title(), f'Confidence Score: {round(preds[\"score\"], 3)*100}%'"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 50,
208
+ "metadata": {},
209
+ "outputs": [
210
+ {
211
+ "data": {
212
+ "text/plain": [
213
+ "tuple"
214
+ ]
215
+ },
216
+ "execution_count": 50,
217
+ "metadata": {},
218
+ "output_type": "execute_result"
219
+ }
220
+ ],
221
+ "source": [
222
+ "predict(text)"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": 47,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "string = 'some text'"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": 49,
237
+ "metadata": {},
238
+ "outputs": [
239
+ {
240
+ "data": {
241
+ "text/plain": [
242
+ "tuple"
243
+ ]
244
+ },
245
+ "execution_count": 49,
246
+ "metadata": {},
247
+ "output_type": "execute_result"
248
+ }
249
+ ],
250
+ "source": [
251
+ "predict(string)"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": []
260
+ }
261
+ ],
262
+ "metadata": {
263
+ "kernelspec": {
264
+ "display_name": "Python 3.9.10",
265
+ "language": "python",
266
+ "name": "python3"
267
+ },
268
+ "language_info": {
269
+ "codemirror_mode": {
270
+ "name": "ipython",
271
+ "version": 3
272
+ },
273
+ "file_extension": ".py",
274
+ "mimetype": "text/x-python",
275
+ "name": "python",
276
+ "nbconvert_exporter": "python",
277
+ "pygments_lexer": "ipython3",
278
+ "version": "3.9.10"
279
+ },
280
+ "orig_nbformat": 4,
281
+ "vscode": {
282
+ "interpreter": {
283
+ "hash": "b29df53a373a75f04ac216b720f486bfd73e41a5a0018838dedd490de94cf09c"
284
+ }
285
+ }
286
+ },
287
+ "nbformat": 4,
288
+ "nbformat_minor": 2
289
+ }